# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import _pickle
import abc
import collections
import contextlib
import functools
import os
import queue
import sys
import threading
import time
import typing
import warnings
from collections import defaultdict, OrderedDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from copy import deepcopy
from multiprocessing import connection, queues
from multiprocessing.managers import SyncManager
from queue import Empty
from textwrap import indent
from typing import Any, TypeVar

import numpy as np
import torch
import torch.nn as nn

from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.base import NO_DEFAULT
from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase
from tensordict.utils import _zip_strict, Buffer
from torch import multiprocessing as mp
from torch.nn import Parameter
from torch.utils.data import IterableDataset

from torchrl._utils import (
    _check_for_faulty_process,
    _ends_with,
    _make_ordinal_device,
    _ProcessNoWarn,
    _replace_last,
    accept_remote_rref_udf_invocation,
    compile_with_warmup,
    logger as torchrl_logger,
    prod,
    rl_warnings,
    VERBOSE,
)
from torchrl.collectors.utils import split_trajectories
from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data import ReplayBuffer
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.env_creator import EnvCreator

from torchrl.envs.llm.transforms.policy_version import PolicyVersion
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
    _aggregate_end_of_traj,
    _make_compatible_policy,
    ExplorationType,
    RandomPolicy,
    set_exploration_type,
)
from torchrl.weight_update import SharedMemWeightSyncScheme
from torchrl.weight_update.weight_sync_schemes import (
    _resolve_model,
    MultiProcessWeightSyncScheme,
    WeightReceiver,
    WeightSender,
    WeightSyncScheme,
)

try:
    from torch.compiler import cudagraph_mark_step_begin
except ImportError:
    # This torch build does not ship ``torch.compiler.cudagraph_mark_step_begin``;
    # expose a stub under the same name that fails loudly if it is ever called.

    def cudagraph_mark_step_begin():
        """Stand-in for the unavailable ``torch.compiler.cudagraph_mark_step_begin``.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("cudagraph_mark_step_begin not implemented.")


_TIMEOUT = 1.0
INSTANTIATE_TIMEOUT = 20
# Should be several orders of magnitude below the time spent collecting a trajectory.
_MIN_TIMEOUT = 1e-3
# Maximum number of times a Dataloader worker may time out while waiting on its queue.
# Overridable via the MAX_IDLE_COUNT environment variable; defaults to the int64 maximum.
_MAX_IDLE_COUNT = int(
    os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)
)

# Exploration mode used by default.
DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM

# True when the interpreter runs on macOS (``sys.platform`` begins with "darwin").
_is_osx = sys.platform.startswith("darwin")

# Module-level generic type variable.
T = TypeVar("T")


class _Interruptor:
    """Lock-protected flag recording whether a process should keep collecting.

    The flag starts in the "collecting" state. :meth:`stop_collection` and
    :meth:`start_collection` flip it, and :meth:`collection_stopped` reads it.
    Every access goes through a ``multiprocessing`` lock.
    """

    # Naming note: "interruptor" was chosen over "interrupter" because it looked more
    # common on Google Trends, even though some IDE spell-checkers flag it.
    def __init__(self):
        self._lock = mp.Lock()
        self._collect = True

    def start_collection(self):
        """Mark collection as active."""
        with self._lock:
            self._collect = True

    def stop_collection(self):
        """Mark collection as stopped."""
        with self._lock:
            self._collect = False

    def collection_stopped(self):
        """Return ``True`` if :meth:`stop_collection` was the most recent state change."""
        with self._lock:
            stopped = not self._collect
        return stopped


class _InterruptorManager(SyncManager):
    """``SyncManager`` subclass through which an ``_Interruptor`` can be shared.

    The class adds no behavior of its own; registering ``_Interruptor`` on it (just
    below) makes the manager able to create interruptor instances that can be
    shared between processes.
    """


_InterruptorManager.register("_Interruptor", callable=_Interruptor)


def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict:
    """Return a copy of ``dictionary`` with every tensor moved to CPU.

    Nested :class:`~collections.OrderedDict` values are processed recursively,
    :class:`torch.Tensor` values are mapped through :meth:`torch.Tensor.cpu`, and
    any other value is kept unchanged. Key order is preserved.

    Args:
        dictionary (OrderedDict): the (possibly nested) mapping to convert.
            Keys may be of any hashable type.

    Returns:
        A new ``OrderedDict`` with the same keys, in the same order.
    """

    def _to_cpu(item):
        if isinstance(item, OrderedDict):
            return recursive_map_to_cpu(item)
        if isinstance(item, torch.Tensor):
            return item.cpu()
        return item

    # Build from (key, value) pairs rather than ``OrderedDict(**mapping)``: keyword
    # expansion only accepts string keys and would raise TypeError on e.g. int keys.
    return OrderedDict((key, _to_cpu(item)) for key, item in dictionary.items())


class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
    """Base class for data collectors.

    Concrete collectors implement :meth:`shutdown`, :meth:`iterator`,
    :meth:`set_seed`, :meth:`state_dict` and :meth:`load_state_dict`. This base
    class provides:

    - iteration helpers (:meth:`__iter__`, :meth:`next`);
    - a helper to map a policy onto a target device;
    - the public weight-update entry point :meth:`update_policy_weights_`. It
      dispatches to a legacy ``WeightUpdaterBase`` when one is set, and to
      per-model weight senders otherwise.
    """

    _task = None
    _iterator = None  # lazily created by :meth:`next`
    total_frames: int
    requested_frames_per_batch: int
    frames_per_batch: int
    trust_policy: bool
    compiled_policy: bool
    cudagraphed_policy: bool
    # Legacy weight-update mechanism; when set, update_policy_weights_ routes to it.
    _weight_updater: WeightUpdaterBase | None = None
    # Scheme-based weight synchronization, keyed by model id (e.g. "policy").
    _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None
    _weight_senders: dict[str, WeightSender] | None = None
    _weight_receivers: dict[str, WeightReceiver] | None = None
    verbose: bool = False

    @property
    def weight_updater(self) -> WeightUpdaterBase | None:
        """The legacy weight updater registered on this collector, or ``None``."""
        return self._weight_updater

    @weight_updater.setter
    def weight_updater(self, value: WeightUpdaterBase | None):
        if value is not None:
            if not isinstance(value, WeightUpdaterBase) and callable(
                value
            ):  # Fall back to default constructor
                value = value()
            value.register_collector(self)
            if value.collector is not self:
                raise RuntimeError("Failed to register collector.")
        self._weight_updater = value

    def _get_policy_and_device(
        self,
        policy: Callable[[Any], Any] | None = None,
        policy_device: Any = NO_DEFAULT,
        env_maker: Any | None = None,
        env_maker_kwargs: dict[str, Any] | None = None,
    ) -> tuple[TensorDictModule, None | Callable[[], dict]]:
        """Util method to get a policy and its device given the collector __init__ inputs.

        We want to copy the policy and then move the data there, not call policy.to(device).

        Args:
            policy (TensorDictModule, optional): a policy to be used
            policy_device (torch.device, optional): the device where the policy should be placed.
                Defaults to self.policy_device
            env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair.
                Not read by this method.
            env_maker_kwargs (a dict, optional): the env_maker function kwargs.
                Not read by this method.

        Returns:
            A ``(policy, get_original_weights)`` tuple.

            - When there is no policy device, or every parameter/buffer already lives on
              ``policy_device``, this is ``(policy, None)``.
            - Otherwise the first element is a deep copy of ``policy`` whose
              parameters/buffers are mapped to ``policy_device``. The second is a
              callable that returns the original policy's weights (``.data``
              TensorDict).

        Raises:
            RuntimeError: if the weight keys of the device copy differ from those of
                the original policy.
        """
        if policy_device is NO_DEFAULT:
            policy_device = self.policy_device

        if not policy_device:
            return policy, None

        if isinstance(policy, nn.Module):
            param_and_buf = TensorDict.from_module(policy, as_module=True)
        else:
            # Because we want to reach the warning
            param_and_buf = TensorDict()

        # ``i`` stays -1 only if the policy exposes no parameter/buffer at all.
        i = -1
        for p in param_and_buf.values(True, True):
            i += 1
            if p.device != policy_device:
                # Then we need casting
                break
        else:
            if i == -1 and not self.trust_policy:
                # We trust that the policy device is adequate
                warnings.warn(
                    "A policy device was provided but no parameter/buffer could be found in "
                    "the policy. Casting to policy_device is therefore impossible. "
                    "The collector will trust that the devices match. To suppress this "
                    "warning, set `trust_policy=True` when building the collector."
                )
            return policy, None

        # Create a stateless policy, then populate this copy with params on device
        def get_original_weights(policy=policy):
            td = TensorDict.from_module(policy)
            return td.data

        # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function
        # Temporarily swap the weights for "meta" tensors so deepcopy does not copy the real data.
        with param_and_buf.data.to("meta").to_module(policy):
            policy_new_device = deepcopy(policy)

        param_and_buf_new_device = param_and_buf.apply(
            functools.partial(_map_weight, policy_device=policy_device),
            filter_empty=False,
        )
        param_and_buf_new_device.to_module(policy_new_device)
        # Sanity check
        if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set(
            get_original_weights().keys(True, True)
        ):
            raise RuntimeError("Failed to map weights. The weight sets mismatch.")
        return policy_new_device, get_original_weights

    def start(self):
        """Starts the collector for asynchronous data collection.

        This method initiates the background collection of data, allowing for decoupling of data collection and training.

        The collected data is typically stored in a replay buffer passed during the collector's initialization.

        .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown`
            when you're done with it to free up resources.

        .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature.
            Ensure you understand the implications for your specific algorithm before using this mode.

        Raises:
            NotImplementedError: If not implemented by a subclass.
        """
        raise NotImplementedError(
            f"Collector start() is not implemented for {type(self).__name__}."
        )

    @contextlib.contextmanager
    def pause(self):
        """Context manager that pauses the collector if it is running free.

        Raises:
            NotImplementedError: If not implemented by a subclass.
        """
        raise NotImplementedError(
            f"Collector pause() is not implemented for {type(self).__name__}."
        )

    def async_shutdown(
        self, timeout: float | None = None, close_env: bool = True
    ) -> None:
        """Shuts down the collector when started asynchronously with the `start` method.

        Args:
            timeout (float, optional): The maximum time to wait for the collector to shutdown.
            close_env (bool, optional): If True, the collector will close the contained environment.
                Defaults to `True`.

        .. seealso:: :meth:`~.start`

        """
        return self.shutdown(timeout=timeout, close_env=close_env)

    def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
        """Extract weights from a model if needed.

        For the new weight sync scheme system, weight preparation is handled
        by the scheme's prepare_weights() method. This method now only handles
        legacy weight updater cases.

        Args:
            weights: Either already-extracted weights or a model to extract from.
            model_id: The model identifier for resolving string paths.

        Returns:
            Extracted weights in the appropriate format.
        """
        # New weight sync schemes handle preparation themselves
        if self._weight_sync_schemes:
            # Just pass through - WeightSender will call scheme.prepare_weights()
            return weights

        # Legacy weight updater path
        return self._legacy_extract_weights(weights, model_id)

    def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any:
        """Legacy weight extraction for old weight updater system.

        Args:
            weights: Either already-extracted weights or a model to extract from.
            model_id: The model identifier.

        Returns:
            ``weights`` unchanged when it is not ``None``. Otherwise, for
            ``model_id == "policy"``, the collector's ``policy_weights`` attribute
            if present; else the ``_policy_weights_dict`` entry for the (first)
            policy device. ``None`` in every other case.
        """
        if weights is None:
            if model_id == "policy" and hasattr(self, "policy_weights"):
                return self.policy_weights
            elif model_id == "policy" and hasattr(self, "_policy_weights_dict"):
                policy_device = (
                    self.policy_device
                    if not isinstance(self.policy_device, (list, tuple))
                    else self.policy_device[0]
                )
                return self._policy_weights_dict.get(policy_device)
            return None

        return weights

    @property
    def _legacy_weight_updater(self) -> bool:
        """Whether a legacy ``WeightUpdaterBase`` is configured on this collector."""
        return self._weight_updater is not None

    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
        weights_dict: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Updates the policy weights for the data collector, accommodating both local and remote execution contexts.

        This method ensures that the policy weights used by the data collector are synchronized with the latest
        trained weights. It supports both local and remote weight updates, depending on the configuration of the
        data collector. The local (download) update is performed before the remote (upload) update, such that weights
        can be transferred to the children workers from a server.

        Args:
            policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be:
                - TensorDictModuleBase: A policy module whose weights will be extracted
                - TensorDictBase: A TensorDict containing weights
                - dict: A regular dict containing weights
                - None: Will try to get weights from server using _get_server_weights()
                When weight senders are registered and neither ``model_id`` nor ``weights_dict`` is given,
                these weights target the ``"policy"`` model if it is registered, otherwise the single
                registered model.
            worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the
                workers that need to be updated. This is relevant when the collector has more than one worker associated
                with it.
            model_id (str | None, optional): The model identifier to update. If provided, only updates this specific
                model. Cannot be used together with weights_dict.
            weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating
                multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes.
                Cannot be used together with model_id or policy_or_weights.

        Raises:
            TypeError: If `worker_ids` is provided but no `weight_updater` is configured.
            ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict), or if the
                target model cannot be determined.
            KeyError: If a target model id has no registered weight sender.

        .. note:: Users should extend the `WeightUpdaterBase` classes to customize
            the weight update logic for specific use cases. This method should not be overwritten.

        .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and
            :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`.

        """
        if self._legacy_weight_updater:
            return self._legacy_weight_update_impl(
                policy_or_weights=policy_or_weights,
                worker_ids=worker_ids,
                model_id=model_id,
                weights_dict=weights_dict,
                **kwargs,
            )
        else:
            return self._weight_update_impl(
                policy_or_weights=policy_or_weights,
                worker_ids=worker_ids,
                model_id=model_id,
                weights_dict=weights_dict,
                **kwargs,
            )

    def _legacy_weight_update_impl(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
        weights_dict: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Forward the update to the legacy ``weight_updater`` callable.

        Raises:
            ValueError: if ``weights_dict`` or ``model_id`` is given; the legacy
                updater only handles a single, unnamed set of weights.
        """
        if weights_dict is not None:
            raise ValueError("weights_dict is not supported with legacy weight updater")
        if model_id is not None:
            raise ValueError("model_id is not supported with legacy weight updater")
        # Fall back to old weight updater system
        self.weight_updater(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )

    def _weight_update_impl(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        model_id: str | None = None,
        weights_dict: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        """Scheme-based weight update: send weights through the registered senders.

        When no sender is registered, the weights are applied locally via
        :meth:`receive_weights`.

        Raises:
            ValueError: on conflicting arguments, or when several senders are
                registered (none named ``"policy"``) and no target is given.
            KeyError: if a target model id has no registered sender.
            RuntimeError: if called while a legacy weight updater is configured.
        """
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
            )
            policy_or_weights = kwargs.pop("policy_weights")

        if weights_dict is not None and model_id is not None:
            raise ValueError("Cannot specify both 'weights_dict' and 'model_id'")

        if weights_dict is not None and policy_or_weights is not None:
            raise ValueError(
                "Cannot specify both 'weights_dict' and 'policy_or_weights'"
            )

        # Priority: new weight sync schemes > old weight updater system
        if self._weight_senders:
            if model_id is not None:
                # Compose weight_dict for the explicitly requested model
                weights_dict = {model_id: policy_or_weights}
            elif weights_dict is None:
                # No explicit target: prefer "policy", else the only registered model.
                # Resolved here (rather than pre-packing positional weights under
                # "policy") so that a collector whose single scheme is not named
                # "policy" can still be updated with update_policy_weights_(weights).
                if "policy" in self._weight_senders:
                    weights_dict = {"policy": policy_or_weights}
                elif len(self._weight_senders) == 1:
                    single_model_id = next(iter(self._weight_senders.keys()))
                    weights_dict = {single_model_id: policy_or_weights}
                else:
                    raise ValueError(
                        "Cannot determine the model to update. Please provide a weights_dict."
                    )
            for target_model_id, weights in weights_dict.items():
                if target_model_id not in self._weight_senders:
                    raise KeyError(
                        f"Model '{target_model_id}' not found in registered weight senders. "
                        f"Available models: {list(self._weight_senders.keys())}"
                    )
                processed_weights = self._extract_weights_if_needed(
                    weights, target_model_id
                )
                # Use new send() API with worker_ids support
                self._weight_senders[target_model_id].send(
                    weights=processed_weights, worker_ids=worker_ids
                )
        elif self._weight_updater is not None:
            # update_policy_weights_ routes legacy updaters elsewhere, so this branch
            # is only reached if this private method is called directly.
            raise RuntimeError(
                "_weight_update_impl was called while a legacy weight updater is "
                "configured; use update_policy_weights_ instead."
            )
        else:
            return self.receive_weights(policy_or_weights)

    def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
        """Apply weights to the local policy when no weight sender is configured.

        Args:
            policy_or_weights (TensorDictBase | nn.Module, optional): weights to load,
                or a module to extract them from.
                - When given, they are applied to ``self.policy`` (if it is an ``nn.Module``).
                - When ``None`` and both ``self._original_policy`` and ``self.policy``
                  are ``nn.Module`` instances, the original policy's weights are
                  mirrored onto ``self.policy``, cast to ``self.policy_device`` if set.
                - In every other case this is a no-op.
        """
        # No weight updater configured
        # For single-process collectors, apply weights locally if explicitly provided
        if policy_or_weights is not None:
            from torchrl.weight_update.weight_sync_schemes import WeightStrategy

            # Use WeightStrategy to apply weights properly
            strategy = WeightStrategy(extract_as="tensordict")

            # Extract weights if needed
            if isinstance(policy_or_weights, nn.Module):
                weights = strategy.extract_weights(policy_or_weights)
            else:
                weights = policy_or_weights

            # Apply to local policy
            if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
                strategy.apply_weights(self.policy, weights)
        elif (
            hasattr(self, "_original_policy")
            and isinstance(self._original_policy, nn.Module)
            and hasattr(self, "policy")
            and isinstance(self.policy, nn.Module)
        ):
            # If no weights were provided, mirror weights from the original (trainer) policy
            from torchrl.weight_update.weight_sync_schemes import WeightStrategy

            strategy = WeightStrategy(extract_as="tensordict")
            weights = strategy.extract_weights(self._original_policy)
            # Cast weights to the policy device before applying
            if self.policy_device is not None:
                weights = weights.to(self.policy_device)
            strategy.apply_weights(self.policy, weights)
        # Otherwise, no action needed - policy is local and changes are immediately visible

    def __iter__(self) -> Iterator[TensorDictBase]:
        """Yield batches from :meth:`iterator`; shut the collector down if it raises."""
        try:
            yield from self.iterator()
        except Exception:
            self.shutdown()
            raise

    def next(self):
        """Return the next collected batch, or ``None`` once the collector is exhausted.

        The underlying iterator is created lazily on the first call.
        """
        try:
            if self._iterator is None:
                self._iterator = iter(self)
            out = next(self._iterator)
            # if any, we don't want the device ref to be passed in distributed settings
            if out is not None and (out.device != "cpu"):
                out = out.copy().clear_device_()
            return out
        except StopIteration:
            return None

    @abc.abstractmethod
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def iterator(self) -> Iterator[TensorDictBase]:
        raise NotImplementedError

    @abc.abstractmethod
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        raise NotImplementedError

    @abc.abstractmethod
    def state_dict(self) -> OrderedDict:
        raise NotImplementedError

    @abc.abstractmethod
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        raise NotImplementedError

    def _read_compile_kwargs(self, compile_policy, cudagraph_policy):
        """Parse the ``compile_policy`` / ``cudagraph_policy`` constructor arguments.

        Each argument is handled as follows:

        - ``False``/``None`` disables the feature.
        - A mapping enables it and is stored as the feature's kwargs.
        - Any other value enables it with empty kwargs.
        """
        self.compiled_policy = compile_policy not in (False, None)
        self.cudagraphed_policy = cudagraph_policy not in (False, None)
        # collections.abc.Mapping replaces the deprecated typing.Mapping alias.
        self.compiled_policy_kwargs = (
            {} if not isinstance(compile_policy, Mapping) else compile_policy
        )
        self.cudagraphed_policy_kwargs = (
            {} if not isinstance(cudagraph_policy, Mapping) else cudagraph_policy
        )

    def __repr__(self) -> str:
        string = f"{self.__class__.__name__}()"
        return string

    def __class_getitem__(self, index):
        # Subscripting the collector class (inherited from IterableDataset) is not supported.
        raise NotImplementedError

    def __len__(self) -> int:
        """Number of batches yielded: ``ceil(total_frames / requested_frames_per_batch)``.

        Raises:
            RuntimeError: if ``total_frames`` is not positive (non-terminating collector).
        """
        if self.total_frames > 0:
            # Ceiling division via negated floor division.
            return -(self.total_frames // -self.requested_frames_per_batch)
        raise RuntimeError("Non-terminating collectors do not have a length")

    def init_updater(self, *args, **kwargs):
        """Initialize the weight updater with custom arguments.

        This method passes the arguments to the weight updater's init method.
        If no weight updater is set, this is a no-op.

        Args:
            *args: Positional arguments for weight updater initialization
            **kwargs: Keyword arguments for weight updater initialization
        """
        if self.weight_updater is not None:
            self.weight_updater.init(*args, **kwargs)


@accept_remote_rref_udf_invocation
class SyncDataCollector(DataCollectorBase):
    """Generic data collector for RL problems. Requires an environment constructor and a policy.

    Args:
        create_env_fn (Callable or EnvBase): a callable that returns an instance of
            :class:`~torchrl.envs.EnvBase` class, or the env itself.
        policy (Callable): Policy to be executed in the environment.
            Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
            If ``None`` is provided, the policy used will be a
            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
            ``action_spec``.
            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
            This is the recommended usage of the collector.
            Other callables are accepted too:
            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
            instances) it will be wrapped in a `nn.Module` first.
            Then, the collector will try to assess if these
            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

            - If the policy forward signature matches any of ``forward(self, tensordict)``,
              ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
              any typing with a single argument typed as a subclass of ``TensorDictBase``)
              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

            - In all other cases the collector will attempt to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
                pickled directly), the ``policy_factory`` should be used instead.

    Keyword Args:
        policy_factory (Callable[[], Callable], optional): a callable that returns
            a policy instance. This is exclusive with the `policy` argument.

            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

        frames_per_batch (int): A keyword-only argument representing the total
            number of elements in a batch.
        total_frames (int): A keyword-only argument representing the total
            number of frames returned by the collector
            during its lifespan. If ``total_frames`` is not divisible by
            ``frames_per_batch``, a warning is emitted and the last batch is
            collected in full, so slightly more than ``total_frames`` frames
            are returned.
            Endless collectors can be created by passing ``total_frames=-1``.
            Defaults to ``-1`` (endless collector).
        device (int, str or torch.device, optional): The generic device of the
            collector. The ``device`` args fills any non-specified device: if
            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
            ``env_device`` is not specified, its value will be set to ``device``.
            Defaults to ``None`` (No default device).
        storing_device (int, str or torch.device, optional): The device on which
            the output :class:`~tensordict.TensorDict` will be stored.
            If ``device`` is passed and ``storing_device`` is ``None``, it will
            default to the value indicated by ``device``.
            For long trajectories, it may be necessary to store the data on a different
            device than the one where the policy and env are executed.
            Defaults to ``None`` (the output tensordict isn't on a specific device,
            leaf tensors sit on the device where they were created).
        env_device (int, str or torch.device, optional): The device on which
            the environment should be cast (or executed if that functionality is
            supported). If not specified and the env has a non-``None`` device,
            ``env_device`` will default to that value. If ``device`` is passed
            and ``env_device=None``, it will default to ``device``. If the value
            as such specified of ``env_device`` differs from ``policy_device``
            and one of them is not ``None``, the data will be cast to ``env_device``
            before being passed to the env (i.e., passing different devices to
            policy and env is supported). Defaults to ``None``.
        policy_device (int, str or torch.device, optional): The device on which
            the policy should be cast.
            If ``device`` is passed and ``policy_device=None``, it will default
            to ``device``. If the value as such specified of ``policy_device``
            differs from ``env_device`` and one of them is not ``None``,
            the data will be cast to ``policy_device`` before being passed to
            the policy (i.e., passing different devices to policy and env is
            supported). Defaults to ``None``.
        create_env_kwargs (dict, optional): Dictionary of kwargs for
            ``create_env_fn``.
        max_frames_per_traj (int, optional): Maximum steps per trajectory.
            Note that a trajectory can span across multiple batches (unless
            ``reset_at_each_iter`` is set to ``True``, see below).
            Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
            If the environment wraps multiple environments together, the number
            of steps is tracked for each environment independently. Negative
            values are allowed, in which case this argument is ignored.
            Defaults to ``None`` (i.e., no maximum number of steps).
        init_random_frames (int, optional): Number of frames for which the
            policy is ignored before it is called. This feature is mainly
            intended to be used in offline/model-based settings, where a
            batch of random trajectories can be used to initialize training.
            If provided, it will be rounded up to the closest multiple of frames_per_batch.
            Defaults to ``None`` (i.e. no random frames).
        reset_at_each_iter (bool, optional): Whether environments should be reset
            at the beginning of a batch collection.
            Defaults to ``False``.
        postproc (Callable, optional): A post-processing transform, such as
            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
            instance.

            .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer
                as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`.

            Defaults to ``None``.
        split_trajs (bool, optional): Boolean indicating whether the resulting
            TensorDict should be split according to the trajectories.
            See :func:`~torchrl.collectors.utils.split_trajectories` for more
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
            or ``torchrl.envs.utils.ExplorationType.MEAN``.
        return_same_td (bool, optional): if ``True``, the same TensorDict
            will be returned at each iteration, with its values
            updated. This feature should be used cautiously: if the same
            tensordict is added to a replay buffer for instance,
            the whole content of the buffer will be identical.
            Default is ``False``.
        interruptor (_Interruptor, optional):
            An _Interruptor object that can be used from outside the class to control rollout collection.
            The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow implementing
            strategies such as preemptively stopping rollout collection.
            Defaults to ``None``.
        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
            Truncated keys can be set through ``env.add_truncated_keys``.
            Defaults to ``False``.
        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
            This isn't compatible with environments with dynamic specs, nor with a ``replay_buffer``.
            Defaults to ``True`` for envs without dynamic specs when no ``replay_buffer`` is passed,
            ``False`` otherwise.
        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
            but populate the buffer instead.
            Defaults to ``None``.

            .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts.
                If the buffer needs to be populated with individual frames as they are collected,
                set ``extend_buffer=False`` (deprecated).

            .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires
                `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms.

        extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
            with single steps. Defaults to `True`.

            .. note:: Setting this to `False` is deprecated and will be removed in a future version.
                Extending the buffer with entire rollouts is the recommended approach for better
                compatibility with postprocessing and trajectory splitting.
        trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy is trusted to be
            compatible with the collector and is not wrapped. Defaults to ``True`` for
            :class:`~tensordict.nn.CudaGraphModule` and :class:`~torchrl.envs.utils.RandomPolicy`
            instances and ``False`` otherwise.
        compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
            using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
            will be used to compile the policy.
        cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
            in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
            If a dictionary of kwargs is passed, it will be used to wrap the policy.
        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
            or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_) cuda synchronization may cause unexpected
            crashes.
            Defaults to ``False``.
        weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
            or its subclass, responsible for updating the policy weights on remote inference workers.
            This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
            Consider using a constructor if the updater needs to be serialized.
        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
            the policy version.
            Defaults to `False`.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
        >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
        >>> collector = SyncDataCollector(
        ...     create_env_fn=env_maker,
        ...     policy=policy,
        ...     total_frames=2000,
        ...     max_frames_per_traj=50,
        ...     frames_per_batch=200,
        ...     init_random_frames=-1,
        ...     reset_at_each_iter=False,
        ...     device="cpu",
        ...     storing_device="cpu",
        ... )
        >>> for i, data in enumerate(collector):
        ...     if i == 2:
        ...         print(data)
        ...         break
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                collector: TensorDict(
                    fields={
                        traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False)
        >>> del collector

    The collector delivers batches of data that are marked with a ``"time"``
    dimension.

    Examples:
        >>> assert data.names[-1] == "time"

    """

    _ignore_rb: bool = False

    def __init__(
        self,
        create_env_fn: (
            EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]]  # noqa: F821
        ),  # noqa: F821
        policy: None
        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
        *,
        policy_factory: Callable[[], Callable] | None = None,
        frames_per_batch: int,
        total_frames: int = -1,
        device: DEVICE_TYPING | None = None,
        storing_device: DEVICE_TYPING | None = None,
        policy_device: DEVICE_TYPING | None = None,
        env_device: DEVICE_TYPING | None = None,
        create_env_kwargs: dict[str, Any] | None = None,
        max_frames_per_traj: int | None = None,
        init_random_frames: int | None = None,
        reset_at_each_iter: bool = False,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        split_trajs: bool | None = None,
        exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
        return_same_td: bool = False,
        reset_when_done: bool = True,
        interruptor=None,
        set_truncated: bool = False,
        use_buffers: bool | None = None,
        replay_buffer: ReplayBuffer | None = None,
        extend_buffer: bool = True,
        local_init_rb: bool | None = None,
        trust_policy: bool | None = None,
        compile_policy: bool | dict[str, Any] | None = None,
        cudagraph_policy: bool | dict[str, Any] | None = None,
        no_cuda_sync: bool = False,
        weight_updater: WeightUpdaterBase
        | Callable[[], WeightUpdaterBase]
        | None = None,
        weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
        track_policy_version: bool | PolicyVersion = False,
        **kwargs,
    ):
        """Build the collector (see the class docstring for the arguments).

        The setup steps below are order-dependent: several helpers read
        attributes assigned by earlier ones (e.g. ``self.env``, the resolved
        devices, ``self.n_env``).

        Raises:
            TypeError: if unknown keyword arguments are passed (``traj_pool`` is
                the only extra keyword accepted through ``**kwargs``).
            ValueError: if ``reset_when_done`` is ``False``.
        """
        # Considered closed until the env, devices and replay buffer are set up
        # (flipped to False further below).
        self.closed = True

        # Initialize environment (instantiate the factory or adopt the instance)
        env = self._init_env(create_env_fn, create_env_kwargs)

        # Initialize policy (resolves policy/policy_factory and trust_policy)
        policy = self._init_policy(policy, policy_factory, env, trust_policy)
        self._read_compile_kwargs(compile_policy, cudagraph_policy)

        # Handle trajectory pool and validate kwargs: ``traj_pool`` is the only
        # extra keyword accepted, anything left over is rejected.
        self._traj_pool_val = kwargs.pop("traj_pool", None)
        if kwargs:
            raise TypeError(
                f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}."
            )

        # Set up devices and synchronization
        self._setup_devices(
            device=device,
            storing_device=storing_device,
            policy_device=policy_device,
            env_device=env_device,
            no_cuda_sync=no_cuda_sync,
        )

        # From here on only ``self.env`` references the env; later helpers may
        # rebind it (appended transforms, device move, StepCounter wrapping).
        self.env: EnvBase = env
        del env

        # Set up policy version tracking (may append a transform to self.env)
        self._setup_policy_version_tracking(track_policy_version)

        # Set up replay buffer
        self._setup_replay_buffer(
            replay_buffer=replay_buffer,
            extend_buffer=extend_buffer,
            local_init_rb=local_init_rb,
            postproc=postproc,
            split_trajs=split_trajs,
            return_same_td=return_same_td,
            use_buffers=use_buffers,
        )

        self.closed = False

        # Validate reset_when_done: only ``True`` is supported, the argument is
        # kept for backward compatibility.
        if not reset_when_done:
            raise ValueError("reset_when_done is deprecated.")
        self.reset_when_done = reset_when_done
        # Number of batched envs; frames_per_batch is split across them below.
        self.n_env = self.env.batch_size.numel()

        # Register collector with policy and env
        if hasattr(policy, "register_collector"):
            policy.register_collector(self)
        if hasattr(self.env, "register_collector"):
            self.env.register_collector(self)

        # Set up policy and weights
        self._setup_policy_and_weights(policy)

        # Apply environment device
        self._apply_env_device()

        # Set up max frames per trajectory (may wrap self.env with a StepCounter)
        self._setup_max_frames_per_traj(max_frames_per_traj)

        # Validate and set total frames
        self.reset_at_each_iter = reset_at_each_iter
        self._setup_total_frames(total_frames, frames_per_batch)

        # Set up init random frames
        self._setup_init_random_frames(init_random_frames, frames_per_batch)

        # Set up postproc
        self._setup_postproc(postproc)

        # Calculate frames per batch (requires self.n_env, set above)
        self._setup_frames_per_batch(frames_per_batch)

        # Set exploration and other options; a falsy exploration_type falls
        # back to DEFAULT_EXPLORATION_TYPE.
        self.exploration_type = (
            exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
        )
        self.return_same_td = return_same_td
        self.set_truncated = set_truncated

        # Create shuttle and rollout buffers
        self._make_shuttle()
        self._maybe_make_final_rollout(make_rollout=self._use_buffers)
        self._set_truncated_keys()

        # Set split trajectories option (``None`` means no splitting)
        if split_trajs is None:
            split_trajs = False
        self.split_trajs = split_trajs
        self._exclude_private_keys = True

        # Set up interruptor and frame tracking (counters presumably advanced
        # during iteration — not visible here)
        self.interruptor = interruptor
        self._frames = 0
        self._iter = -1

        # Set up weight synchronization
        self._setup_weight_sync(weight_updater, weight_sync_schemes)

    def _init_env(
        self,
        create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],
        create_env_kwargs: dict[str, Any] | None,
    ) -> EnvBase:
        """Instantiate the environment, or adopt an existing instance.

        Args:
            create_env_fn: an :class:`EnvBase` instance, or a callable returning one.
            create_env_kwargs: keyword arguments for ``create_env_fn``. When an env
                instance is passed, they can only be forwarded to batched envs
                (through ``update_kwargs``).

        Returns:
            The environment to collect from.

        Raises:
            RuntimeError: if kwargs are passed along with a non-batched env instance.
        """
        from torchrl.envs.batched_envs import BatchedEnvBase

        env_kwargs = {} if create_env_kwargs is None else create_env_kwargs

        if not isinstance(create_env_fn, EnvBase):
            # A factory: build the env with the provided kwargs.
            return create_env_fn(**env_kwargs)

        env = create_env_fn
        if not env_kwargs:
            return env
        if not isinstance(env, BatchedEnvBase):
            raise RuntimeError(
                "kwargs were passed to SyncDataCollector but they can't be set "
                f"on environment of type {type(create_env_fn)}."
            )
        env.update_kwargs(env_kwargs)
        return env

    def _init_policy(
        self,
        policy: TensorDictModule | Callable | None,
        policy_factory: Callable[[], Callable] | None,
        env: EnvBase,
        trust_policy: bool | None,
    ) -> TensorDictModule | Callable:
        """Resolve the policy to use and decide whether it is trusted.

        Args:
            policy: the user-provided policy, or ``None``.
            policy_factory: zero-argument callable building the policy; mutually
                exclusive with ``policy``.
            env: the environment; its ``full_action_spec`` backs the
                :class:`RandomPolicy` used when neither ``policy`` nor
                ``policy_factory`` is given.
            trust_policy: explicit trust flag. ``None`` trusts only
                :class:`RandomPolicy` and :class:`CudaGraphModule` instances.

        Returns:
            The resolved policy. Also sets ``self.trust_policy`` and, when the
            policy has a ``state_dict`` attribute, ``self._policy_w_state_dict``.

        Raises:
            TypeError: if both ``policy`` and ``policy_factory`` are given.
        """
        if policy is not None and policy_factory is not None:
            raise TypeError("policy_factory cannot be used with policy argument.")
        if policy is None:
            policy = (
                policy_factory()
                if policy_factory is not None
                else RandomPolicy(env.full_action_spec)
            )

        # Keep a handle on the object that owns the state_dict, if any.
        if hasattr(policy, "state_dict"):
            self._policy_w_state_dict = policy

        self.trust_policy = (
            isinstance(policy, (RandomPolicy, CudaGraphModule))
            if trust_policy is None
            else trust_policy
        )
        return policy

    def _setup_devices(
        self,
        device: DEVICE_TYPING | None,
        storing_device: DEVICE_TYPING | None,
        policy_device: DEVICE_TYPING | None,
        env_device: DEVICE_TYPING | None,
        no_cuda_sync: bool,
    ) -> None:
        """Resolve the storing/policy/env devices and their sync hooks.

        Resolution is delegated to ``self._get_devices``; each resolved device
        gets its synchronization callable from :meth:`_get_sync_fn`.

        Sets ``storing_device``/``_sync_storage``, ``env_device``/``_sync_env``,
        ``policy_device``/``_sync_policy``, ``device``, ``no_cuda_sync`` and
        ``_cast_to_policy_device`` (``True`` when policy and env devices differ).
        """
        resolved_storing, resolved_policy, resolved_env = self._get_devices(
            storing_device=storing_device,
            policy_device=policy_device,
            env_device=env_device,
            device=device,
        )

        self.storing_device = resolved_storing
        self._sync_storage = self._get_sync_fn(resolved_storing)
        self.env_device = resolved_env
        self._sync_env = self._get_sync_fn(resolved_env)
        self.policy_device = resolved_policy
        self._sync_policy = self._get_sync_fn(resolved_policy)

        self.device = device
        self.no_cuda_sync = no_cuda_sync
        self._cast_to_policy_device = resolved_policy != resolved_env

    def _get_sync_fn(self, device: torch.device | None) -> Callable:
        """Return the synchronization callable to use for ``device``.

        * ``None`` or a CUDA device: a no-op (CUDA handles the sync).
        * Any other device: the ``synchronize`` function of the first available
          backend among CUDA, MPS and NPU. If none is available, a no-op is
          returned for CPU devices.

        Raises:
            RuntimeError: for a non-CPU, non-CUDA device when no supported
                backend is available.
        """
        if device is None or device.type == "cuda":
            return _do_nothing
        if torch.cuda.is_available():
            return torch.cuda.synchronize
        if torch.backends.mps.is_available() and hasattr(torch, "mps"):
            return torch.mps.synchronize
        if hasattr(torch, "npu") and torch.npu.is_available():
            return torch.npu.synchronize
        if device.type == "cpu":
            return _do_nothing
        raise RuntimeError("Non supported device")

    def _setup_policy_version_tracking(
        self, track_policy_version: bool | PolicyVersion
    ) -> None:
        """Attach a policy-version tracker to ``self.env`` when requested.

        ``True`` creates a fresh :class:`PolicyVersion` transform. Any object
        exposing ``increment_version`` is used as the tracker as-is. In both
        cases the tracker is appended to ``self.env`` and stored in
        ``self.policy_version_tracker``. Any other value disables tracking
        (``self.policy_version_tracker`` is set to ``None``).

        Raises:
            RuntimeError: if ``True`` is passed while ``self.env`` is a batched env.
        """
        if track_policy_version is True:
            from torchrl.envs.batched_envs import BatchedEnvBase

            if isinstance(self.env, BatchedEnvBase):
                raise RuntimeError(
                    "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
                    "and pass that transform to the collector."
                )
            tracker = PolicyVersion()
        elif hasattr(track_policy_version, "increment_version"):
            # NOTE(review): the error above suggests adding the transform to the
            # env manually and passing it here, yet this branch appends it to
            # ``self.env`` again — confirm that is intended.
            tracker = track_policy_version
        else:
            self.policy_version_tracker = None
            return

        self.policy_version_tracker = tracker
        self.env = self.env.append_transform(tracker)  # type: ignore

    def _setup_replay_buffer(
        self,
        replay_buffer: ReplayBuffer | None,
        extend_buffer: bool,
        local_init_rb: bool | None,
        postproc: Callable | None,
        split_trajs: bool | None,
        return_same_td: bool,
        use_buffers: bool | None,
    ) -> None:
        """Store the replay-buffer options and check they are mutually compatible.

        Sets ``self.replay_buffer``, ``self.extend_buffer``, ``self.local_init_rb``
        and ``self._use_buffers``.

        Args:
            replay_buffer: buffer to populate instead of yielding batches, or ``None``.
            extend_buffer: whether the buffer is extended with whole rollouts.
            local_init_rb: ``None`` resolves to ``False`` and, when a replay buffer
                is passed, emits a ``FutureWarning``.
            postproc: post-processing callable (requires ``extend_buffer`` when a
                replay buffer is used).
            split_trajs: trajectory splitting flag (requires ``extend_buffer`` when
                a replay buffer is used).
            return_same_td: must be ``False`` when a replay buffer is used.
            use_buffers: ``None`` resolves to ``True`` only for envs without
                dynamic specs and when no replay buffer is passed.

        Raises:
            TypeError: on incompatible options. These checks are skipped when the
                class sets ``_ignore_rb``.
        """
        self.replay_buffer = replay_buffer
        self.extend_buffer = extend_buffer

        # Handle local_init_rb deprecation: an unset value resolves to the
        # deprecated ``False``; the warning is only emitted in that case.
        if local_init_rb is None:
            local_init_rb = False
            if replay_buffer is not None:
                warnings.warn(
                    "local_init_rb=False is deprecated and will be removed in v0.12. "
                    "The new storage-level initialization provides better performance.",
                    FutureWarning,
                )
        self.local_init_rb = local_init_rb

        # Validate replay buffer compatibility
        if self.replay_buffer is not None and not self._ignore_rb:
            if postproc is not None and not self.extend_buffer:
                raise TypeError(
                    "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True."
                )
            if split_trajs not in (None, False) and not self.extend_buffer:
                raise TypeError(
                    "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True."
                )
            if return_same_td:
                # Rejected regardless of extend_buffer, so the message must not
                # suggest extend_buffer=True as a remedy.
                raise TypeError(
                    "return_same_td must be False when a replay buffer is passed."
                )
            if use_buffers:
                raise TypeError("replay_buffer is exclusive with use_buffers.")

        if use_buffers is None:
            use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
        self._use_buffers = use_buffers

    def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None:
        """Resolve, wrap and optionally compile the policy; record its weights.

        * ``self._original_policy``: the policy as received.
        * ``self.policy``: the policy returned by ``self._get_policy_and_device``.
        * ``self._wrapped_policy``: the callable used for collection. It is wrapped
          with :func:`_make_compatible_policy` unless the policy is trusted, then
          optionally compiled and/or wrapped in a :class:`CudaGraphModule`.
        * ``self.policy_weights``: ``.data`` of ``TensorDict.from_module`` applied
          to the (pre-compilation) wrapped policy, or an empty TensorDict when it
          is not an ``nn.Module``.

        Raises:
            TypeError: if an untrusted policy cannot be wrapped.
        """
        self._original_policy = policy
        resolved, self.get_weights_fn = self._get_policy_and_device(policy=policy)
        self.policy = resolved

        if self.trust_policy:
            self._wrapped_policy = resolved
        else:
            env = getattr(self, "env", None)
            try:
                wrapped = _make_compatible_policy(
                    policy=resolved,
                    observation_spec=getattr(env, "observation_spec", None),
                    env=self.env,
                )
            except (TypeError, AttributeError, ValueError) as err:
                raise TypeError(
                    "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
                ) from err
            self._wrapped_policy = wrapped

        # Extract policy weights before any compilation/cudagraph wrapping.
        candidate = self._wrapped_policy
        self.policy_weights = (
            TensorDict.from_module(candidate, as_module=True).data
            if isinstance(candidate, nn.Module)
            else TensorDict()
        )

        # Apply compilation, then cudagraph capture, on top of the wrapped policy.
        if self.compiled_policy:
            self._wrapped_policy = compile_with_warmup(
                self._wrapped_policy, **self.compiled_policy_kwargs
            )
        if self.cudagraphed_policy:
            self._wrapped_policy = CudaGraphModule(
                self._wrapped_policy,
                in_keys=[],
                out_keys=[],
                device=self.policy_device,
                **self.cudagraphed_policy_kwargs,
            )

    def _apply_env_device(self) -> None:
        """Move the env to ``self.env_device``, or adopt the env's own device.

        Also sets ``self._cast_to_env_device``. It is truthy when data must be
        cast for the policy (``self._cast_to_policy_device``) or when the env
        device differs from ``self.storing_device``.
        """
        if not self.env_device:
            # No env device requested: inherit the env's device when it has one.
            current = self.env.device
            if current is not None:
                self.env_device = current
        else:
            self.env: EnvBase = self.env.to(self.env_device)

        self._cast_to_env_device = self._cast_to_policy_device or (
            self.env.device != self.storing_device
        )

    def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None:
        """Store the per-trajectory step limit and enforce it with a StepCounter.

        ``None`` is stored as ``0``. For a positive limit, ``self.env`` is wrapped
        in a :class:`TransformedEnv` with ``StepCounter(max_steps=limit)``.
        Non-positive limits leave the env untouched.

        Raises:
            ValueError: if a positive limit is requested while the env output
                spec already contains a ``"step_count"`` key.
        """
        limit = 0 if max_frames_per_traj is None else int(max_frames_per_traj)
        self.max_frames_per_traj = limit
        if limit <= 0:
            return

        # Refuse to stack a second StepCounter on top of an existing one.
        for spec_key in self.env.output_spec.keys(True, True):
            key_parts = (spec_key,) if isinstance(spec_key, str) else spec_key
            if "step_count" in key_parts:
                raise ValueError(
                    "A 'step_count' key is already present in the environment "
                    "and the 'max_frames_per_traj' argument may conflict with "
                    "a 'StepCounter' that has already been set. "
                    "Possible solutions: Set max_frames_per_traj to 0 or "
                    "remove the StepCounter limit from the environment transforms."
                )
        self.env = TransformedEnv(self.env, StepCounter(max_steps=limit))

    def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None:
        """Validate and store ``self.total_frames``.

        ``None`` or a negative value means an endless collector and is stored as
        ``float("inf")``; otherwise the value is stored as an ``int``. A warning
        (subject to ``rl_warnings()``) reports the extra frames collected when
        ``total_frames`` is not a multiple of ``frames_per_batch``.
        """
        if total_frames is None or total_frames < 0:
            total_frames = float("inf")
        else:
            remainder = total_frames % frames_per_batch
            if remainder != 0 and rl_warnings():
                # Trailing spaces keep sentences separated once concatenated.
                warnings.warn(
                    f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
                    f"This means {frames_per_batch - remainder} additional frames will be collected. "
                    "To silence this message, set the environment variable RL_WARNINGS to False."
                )
        self.total_frames = (
            int(total_frames) if total_frames != float("inf") else total_frames
        )

    def _setup_init_random_frames(
        self, init_random_frames: int | None, frames_per_batch: int
    ) -> None:
        """Store the number of initial frames collected with a random policy.

        ``None`` and ``-1`` both mean "no random frames" and are stored as ``0``.
        A warning (subject to ``rl_warnings()``) is emitted when the value is not
        a multiple of ``frames_per_batch``, reporting the rounded-up amount.
        """
        self.init_random_frames = (
            int(init_random_frames) if init_random_frames not in (None, -1) else 0
        )
        if (
            init_random_frames not in (-1, None, 0)
            and init_random_frames % frames_per_batch != 0
            and rl_warnings()
        ):
            # Round up to the next multiple of frames_per_batch.
            rounded = -(-init_random_frames // frames_per_batch) * frames_per_batch
            warnings.warn(
                f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
                f"this results in more init_random_frames than requested ({rounded}). "
                "To silence this message, set the environment variable RL_WARNINGS to False."
            )

    def _setup_postproc(self, postproc: Callable | None) -> None:
        """Store the post-processing transform, moved to the storing device.

        The move happens only when the postproc has a ``to`` method and a storing
        device is set. The moved object replaces ``self.postproc`` only if ``to``
        returned a new, non-``None`` object.
        """
        self.postproc = postproc
        if postproc is None or not hasattr(postproc, "to") or not self.storing_device:
            return
        moved = postproc.to(self.storing_device)
        if moved is not None and moved is not postproc:
            self.postproc = moved

    def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
        """Calculate and validate frames per batch."""
        if frames_per_batch % self.n_env != 0 and rl_warnings():
            warnings.warn(
                f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
                f" this results in more frames_per_batch per iteration that requested"
                f" ({-(-frames_per_batch // self.n_env) * self.n_env}). "
                "To silence this message, set the environment variable RL_WARNINGS to False."
            )
        self.frames_per_batch = -(-frames_per_batch // self.n_env)
        self.requested_frames_per_batch = self.frames_per_batch * self.n_env

    def _setup_weight_sync(
        self,
        weight_updater: WeightUpdaterBase | Callable | None,
        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
    ) -> None:
        """Set up weight synchronization system."""
        if weight_sync_schemes is not None:
            # Use new simplified weight synchronization system
            self._weight_sync_schemes = weight_sync_schemes
            self._weight_senders = {}
            # For single-process collectors, we don't need senders/receivers
            # The policy is local and changes are immediately visible
            # Senders will be set up in multiprocess collectors during _run_processes
            self.weight_updater = None  # Don't use legacy system
        elif weight_updater is not None:
            # Use legacy weight updater system if explicitly provided
            if not isinstance(weight_updater, WeightUpdaterBase):
                if callable(weight_updater):
                    weight_updater = weight_updater()
                else:
                    raise TypeError(
                        f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead."
                    )
            warnings.warn(
                "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. "
                "This will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.weight_updater = weight_updater
            self._weight_sync_schemes = None
            self._weight_senders = {}
        else:
            # No weight sync needed for single-process collectors
            self.weight_updater = None
            self._weight_sync_schemes = None
            self._weight_senders = {}

    @property
    def _traj_pool(self):
        """Return this collector's trajectory-id pool, creating it on first access."""
        existing = getattr(self, "_traj_pool_val", None)
        if existing is not None:
            return existing
        pool = _TrajectoryPool()
        self._traj_pool_val = pool
        return pool

    def _make_shuttle(self):
        """Reset the env and build the shuttle tensordict.

        The shuttle is a tensordict that carries data from the env to the policy
        and back. Its device is cleared when the policy and env devices differ
        or the env device is unset (recorded in ``self._shuttle_has_no_device``).
        It is then tagged with a ``("collector", "traj_ids")`` entry holding fresh
        ids from the trajectory pool, viewed with the env batch size.
        """
        with torch.no_grad():
            shuttle = self.env.reset()
        self._shuttle = shuttle

        drop_device = self.policy_device != self.env_device or self.env_device is None
        self._shuttle_has_no_device = drop_device
        if drop_device:
            shuttle.clear_device_()

        fresh_ids = self._traj_pool.get_traj_and_increment(
            self.n_env, device=self.storing_device
        )
        self._shuttle.set(("collector", "traj_ids"), fresh_ids.view(self.env.batch_size))

    def _maybe_make_final_rollout(self, make_rollout: bool):
        """Record the policy output keys and optionally pre-allocate the rollout buffer.

        Always sets ``self._policy_output_keys`` (keys copied from the policy output
        into the shuttle during :meth:`rollout`) and ``self._env_output_keys`` (keys
        of the env's full observation, done and reward specs).

        Policy output keys come from the first applicable source:

        1. the wrapped policy's ``spec``, only when ``make_rollout`` is ``True`` and
           every spec value is non-``None``;
        2. the wrapped policy's ``out_keys``, only when ``make_rollout`` is ``False``;
        3. otherwise, a probe call of the policy on a copy of ``self._shuttle``.
           Entries that are new are kept, as are entries replaced by a tensor that
           is on another device or has different values (see the NOTE(review) on
           ``filter_policy``). This path reads ``self._shuttle``, which must exist.

        Args:
            make_rollout (bool): if ``True``, also builds ``self._final_rollout``: a
                zeroed buffer of shape ``(*env.batch_size, self.frames_per_batch)``
                whose last dim is named ``"time"``. It holds env entries, policy
                entries and ``("collector", "traj_ids")``.

        Raises:
            RuntimeError: if the probe call returns a ``LazyStackedTensorDict``
                with exclusive keys.
        """
        if make_rollout:
            with torch.no_grad():
                self._final_rollout = self.env.fake_tensordict()

            # If storing device is not None, we use this to cast the storage.
            # If it is None and the env and policy are on the same device,
            # the storing device is already the same as those, so we don't need
            # to consider this use case.
            # In all other cases, we can't really put a device on the storage,
            # since at least one data source has a device that is not clear.
            if self.storing_device:
                self._final_rollout = self._final_rollout.to(
                    self.storing_device, non_blocking=True
                )
            else:
                # erase all devices
                self._final_rollout.clear_device_()

        # If the policy has a valid spec, we use it.
        # Note: _policy_output_keys is a set in this branch and a list in the others.
        self._policy_output_keys = set()
        if (
            make_rollout
            and hasattr(self._wrapped_policy, "spec")
            and self._wrapped_policy.spec is not None
            and all(v is not None for v in self._wrapped_policy.spec.values(True, True))
        ):
            # NOTE(review): if every policy spec key is already present in the env's
            # fake tensordict (e.g. only "action"), this condition is False and
            # _policy_output_keys stays an empty set. rollout() then updates the
            # shuttle with no policy keys when policy_output is a separate object.
            # Confirm this is intended.
            if any(
                key not in self._final_rollout.keys(isinstance(key, tuple))
                for key in self._wrapped_policy.spec.keys(True, True)
            ):
                # if policy spec is non-empty, all the values are not None and the keys
                # match the out_keys we assume the user has given all relevant information
                # the policy could have more keys than the env:
                policy_spec = self._wrapped_policy.spec
                if policy_spec.ndim < self._final_rollout.ndim:
                    policy_spec = policy_spec.expand(self._final_rollout.shape)
                for key, spec in policy_spec.items(True, True):
                    self._policy_output_keys.add(key)
                    if key in self._final_rollout.keys(True):
                        continue
                    # Keys the env does not know about get a zero placeholder.
                    self._final_rollout.set(key, spec.zero())
        elif (
            not make_rollout
            and hasattr(self._wrapped_policy, "out_keys")
            and self._wrapped_policy.out_keys
        ):
            self._policy_output_keys = list(self._wrapped_policy.out_keys)
        else:
            if make_rollout:
                # otherwise, we perform a small number of steps with the policy to
                # determine the relevant keys with which to pre-populate _final_rollout.
                # This is the safest thing to do if the spec has None fields or if there is
                # no spec at all.
                # See #505 for additional context.
                self._final_rollout.update(self._shuttle.copy())
            with torch.no_grad():
                policy_input = self._shuttle.copy()
                if self.policy_device:
                    policy_input = policy_input.to(self.policy_device)
                # we cast to policy device, we'll deal with the device later
                # Shallow copy: keeps references to the input tensors, to detect
                # entries the policy replaced.
                policy_input_copy = policy_input.copy()
                policy_input_clone = (
                    policy_input.clone()
                )  # to test if values have changed in-place
                if self.compiled_policy:
                    cudagraph_mark_step_begin()
                policy_output = self._wrapped_policy(policy_input)

                # check that we don't have exclusive keys, because they don't appear in keys
                def check_exclusive(val):
                    if (
                        isinstance(val, LazyStackedTensorDict)
                        and val._has_exclusive_keys
                    ):
                        raise RuntimeError(
                            "LazyStackedTensorDict with exclusive keys are not permitted in collectors. "
                            "Consider using a placeholder for missing keys."
                        )

                policy_output._fast_apply(
                    check_exclusive, call_on_nested=True, filter_empty=True
                )

                # Use apply, because it works well with lazy stacks
                # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
                # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
                # changed them here).
                # This will cause a failure to update entries when policy and env device mismatch and
                # casting is necessary.
                # Returns the output value to keep the entry, or None (filtered out).
                # NOTE(review): `~torch.isclose(a, b).any()` parses as
                # `~(torch.isclose(a, b).any())`, i.e. "no element is close". A replaced
                # tensor is therefore kept only if *every* element differs. Confirm this
                # is intended rather than `(~torch.isclose(a, b)).any()`.
                def filter_policy(name, value_output, value_input, value_input_clone):
                    if (value_input is None) or (
                        (value_output is not value_input)
                        and (
                            value_output.device != value_input_clone.device
                            or ~torch.isclose(value_output, value_input_clone).any()
                        )
                    ):
                        return value_output

                filtered_policy_output = policy_output.apply(
                    filter_policy,
                    policy_input_copy,
                    policy_input_clone,
                    default=None,
                    filter_empty=True,
                    named=True,
                )
                self._policy_output_keys = list(
                    self._policy_output_keys.union(
                        set(filtered_policy_output.keys(True, True))
                    )
                )
                if make_rollout:
                    self._final_rollout.update(
                        policy_output.select(*self._policy_output_keys)
                    )
                del filtered_policy_output, policy_output, policy_input

        _env_output_keys = []
        for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
            _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
        self._env_output_keys = _env_output_keys
        if make_rollout:
            # Turn the single-step template into a zeroed buffer with a trailing
            # time dimension of length frames_per_batch.
            self._final_rollout = (
                self._final_rollout.unsqueeze(-1)
                .expand(*self.env.batch_size, self.frames_per_batch)
                .clone()
                .zero_()
            )

            # in addition to outputs of the policy, we add traj_ids to
            # _final_rollout which will be collected during rollout
            self._final_rollout.set(
                ("collector", "traj_ids"),
                torch.zeros(
                    *self._final_rollout.batch_size,
                    dtype=torch.int64,
                    device=self.storing_device,
                ),
            )
            self._final_rollout.refine_names(..., "time")

    def _set_truncated_keys(self):
        self._truncated_keys = []
        if self.set_truncated:
            if not any(_ends_with(key, "truncated") for key in self.env.done_keys):
                raise RuntimeError(
                    "set_truncated was set to True but no truncated key could be found "
                    "in the environment. Make sure the truncated keys are properly set using "
                    "`env.add_truncated_keys()` before passing the env to the collector."
                )
            self._truncated_keys = [
                key for key in self.env.done_keys if _ends_with(key, "truncated")
            ]

    @classmethod
    def _get_devices(
        cls,
        *,
        storing_device: torch.device,
        policy_device: torch.device,
        env_device: torch.device,
        device: torch.device,
    ):
        """Resolve the storing, policy and env devices.

        Each specific device falls back to ``device`` when it is falsy. Every
        result goes through ``_make_ordinal_device``. If no storing device results
        but the env and policy devices are equal, the env device is used for storage.

        Returns:
            tuple: ``(storing_device, policy_device, env_device)``.
        """
        device = _make_ordinal_device(torch.device(device) if device else device)

        def _resolve(candidate):
            # Falsy candidates inherit the (already resolved) default device.
            return _make_ordinal_device(torch.device(candidate) if candidate else device)

        storing_device = _resolve(storing_device)
        policy_device = _resolve(policy_device)
        env_device = _resolve(env_device)
        if storing_device is None and env_device == policy_device:
            storing_device = env_device
        return storing_device, policy_device, env_device

    # for RPC
    def next(self):
        """Return the next batch from the parent collector (re-exposed here for RPC)."""
        batch = super().next()
        return batch

    # for RPC
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        **kwargs,
    ) -> None:
        """Update the policy weights by delegating to the parent implementation (re-exposed for RPC).

        Args:
            policy_or_weights: new weights, or a policy to read them from. Forwarded as-is.
            worker_ids: forwarded as-is to the parent implementation.
            **kwargs: forwarded to the parent implementation. The deprecated
                ``policy_weights`` keyword is still accepted; it emits a
                ``DeprecationWarning`` and overrides ``policy_or_weights``.
        """
        if "policy_weights" in kwargs:
            # stacklevel=2 attributes the warning to the caller, consistent with
            # the other deprecation warnings of this collector.
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            policy_or_weights = kwargs.pop("policy_weights")

        super().update_policy_weights_(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )

    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        """Seed the environment(s) held by the collector.

        This delegates to ``self.env.set_seed`` and returns its result.

        Args:
            seed (int): seed to use for the environment.
            static_seed (bool, optional): if ``True``, the seed is not incremented.
                Defaults to ``False``.

        Returns:
            The seed returned by the environment. When several environments are
            batched, the seed is incremented for each of them, and the returned
            value is the seed of the last one.

        Examples:
            >>> from torchrl.envs import ParallelEnv
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> from tensordict.nn import TensorDictModule
            >>> from torch import nn
            >>> env_fn = lambda: GymEnv("Pendulum-v1")
            >>> env_fn_parallel = ParallelEnv(6, env_fn)
            >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
            >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
            >>> out_seed = collector.set_seed(1)  # out_seed = 6

        """
        return self.env.set_seed(seed, static_seed=static_seed)

    def _increment_frames(self, numel):
        self._frames += numel
        completed = self._frames >= self.total_frames
        if completed:
            self.env.close()
        return completed

    def iterator(self) -> Iterator[TensorDictBase]:
        """Iterate over the collector, running one rollout per step until ``total_frames`` is reached.

        Yields:
            * the post-processed batch itself when ``return_same_td`` is ``True``,
              after synchronizing any recorded CUDA events;
            * ``None`` when data goes to the replay buffer (extended here, or added
              step by step inside :meth:`rollout`);
            * otherwise a clone of the post-processed batch, so batches yielded
              earlier are not overwritten by later in-place rollouts.
        """
        if (
            not self.no_cuda_sync
            and self.storing_device
            and self.storing_device.type == "cuda"
        ):
            # Dedicated CUDA stream on the storing device (priority=-1); its event
            # is synchronized before yielding in return_same_td mode.
            stream = torch.cuda.Stream(self.storing_device, priority=-1)
            event = stream.record_event()
            streams = [stream]
            events = [event]
        elif not self.no_cuda_sync and self.storing_device is None:
            streams = []
            events = []
            # this way of checking cuda is robust to lazy stacks with mismatching shapes
            cuda_devices = set()

            def cuda_check(tensor: torch.Tensor):
                if tensor.is_cuda:
                    cuda_devices.add(tensor.device)

            if not self._use_buffers:
                # This may be a bit dangerous as `torch.device("cuda")` may not have a precise
                # device associated, whereas `tensor.device` always has
                for spec in self.env.specs.values(True, True):
                    if spec.device is not None and spec.device.type == "cuda":
                        if ":" not in str(spec.device):
                            raise RuntimeError(
                                "A cuda spec did not have a device associated. Make sure to "
                                "pass `'cuda:device_num'` to each spec device."
                            )
                        cuda_devices.add(spec.device)
            else:
                self._final_rollout.apply(cuda_check, filter_empty=True)
            for device in cuda_devices:
                streams.append(torch.cuda.Stream(device, priority=-1))
                events.append(streams[-1].record_event())
        else:
            streams = []
            events = []
        with contextlib.ExitStack() as stack:
            for stream in streams:
                stack.enter_context(torch.cuda.stream(stream))

            while self._frames < self.total_frames:
                self._iter += 1
                if self.verbose:
                    torchrl_logger.info("Collector: rollout.")
                tensordict_out = self.rollout()
                if tensordict_out is None:
                    # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out
                    #  frames are updated within the rollout function
                    if self.verbose:
                        torchrl_logger.info("Collector: No tensordict_out. Yielding.")
                    yield
                    continue
                self._increment_frames(tensordict_out.numel())
                tensordict_out = self._postproc(tensordict_out)
                if self.verbose:
                    torchrl_logger.info("Collector: postproc done.")
                if self.return_same_td:
                    # This is used with multiprocessed collectors to use the buffers
                    # stored in the tensordict.
                    if events:
                        for event in events:
                            event.record()
                            event.synchronize()
                    yield tensordict_out
                elif self.replay_buffer is not None and not self._ignore_rb:
                    self.replay_buffer.extend(tensordict_out)
                    if self.verbose:
                        # Both parts are f-strings so the write count is interpolated.
                        torchrl_logger.info(
                            f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
                            f"Buffer write count: {self.replay_buffer.write_count}. Yielding."
                        )
                    yield
                else:
                    # we must clone the values, as the tensordict is updated in-place.
                    # otherwise the following code may break:
                    # >>> for i, data in enumerate(collector):
                    # >>>      if i == 0:
                    # >>>          data0 = data
                    # >>>      elif i == 1:
                    # >>>          data1 = data
                    # >>>      else:
                    # >>>          break
                    # >>> assert data0["done"] is not data1["done"]
                    yield tensordict_out.clone()

    def start(self):
        """Start collecting data asynchronously in a background thread.

        Collected data is written to the replay buffer provided at construction
        time, so the training loop can run independently of data collection. The
        thread is a daemon thread and does not keep the interpreter alive. Calling
        this while the collector is already running does nothing.

        Raises:
            RuntimeError: If no replay buffer is defined during the collector's initialization.

        Example:
            >>> import time
            >>> from functools import partial
            >>>
            >>> import tqdm
            >>>
            >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
            >>> from torchrl.envs import GymEnv, set_gym_backend
            >>> import ale_py
            >>>
            >>> # Set the gym backend to gymnasium
            >>> set_gym_backend("gymnasium").set()
            >>>
            >>> if __name__ == "__main__":
            ...     # Create a random policy for the Pong environment
            ...     env = GymEnv("ALE/Pong-v5")
            ...     policy = RandomPolicy(env.action_spec)
            ...
            ...     # Initialize a shared replay buffer
            ...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
            ...
            ...     # Create a synchronous data collector
            ...     collector = SyncDataCollector(
            ...         env,
            ...         policy=policy,
            ...         replay_buffer=rb,
            ...         frames_per_batch=256,
            ...         total_frames=-1,
            ...     )
            ...
            ...     # Progress bar to track the number of collected frames
            ...     pbar = tqdm.tqdm(total=100_000)
            ...
            ...     # Start the collector asynchronously
            ...     collector.start()
            ...
            ...     # Track the write count of the replay buffer
            ...     prec_wc = 0
            ...     while True:
            ...         wc = rb.write_count
            ...         c = wc - prec_wc
            ...         prec_wc = wc
            ...
            ...         # Update the progress bar
            ...         pbar.update(c)
            ...         pbar.set_description(f"Write Count: {rb.write_count}")
            ...
            ...         # Check the write count every 0.5 seconds
            ...         time.sleep(0.5)
            ...
            ...         # Stop when the desired number of frames is reached
            ...         if rb.write_count >= 100_000:
            ...             break
            ...
            ...     # Shut down the collector
            ...     collector.async_shutdown()
        """
        if self.replay_buffer is None:
            raise RuntimeError("Replay buffer must be defined for execution.")
        if self.is_running():
            return
        self._stop = False
        # Daemon thread: it dies with the main program instead of blocking its exit.
        self._thread = threading.Thread(target=self._run_iterator, daemon=True)
        self._thread.start()

    def _run_iterator(self):
        for _ in self:
            if self._stop:
                return

    def is_running(self):
        """Return whether the background collection thread started by :meth:`start` is alive."""
        if not hasattr(self, "_thread"):
            return False
        return self._thread.is_alive()

    def async_shutdown(
        self, timeout: float | None = None, close_env: bool = True
    ) -> None:
        """Stop asynchronous collection started with :meth:`start`, then shut the collector down.

        Sets the stop flag, waits for the background thread if it is alive, and
        calls :meth:`shutdown`. The thread checks the flag only after each item
        yielded by the collector, so it may finish the batch in progress first.

        Args:
            timeout (float, optional): maximum number of seconds to wait for the
                background thread, passed to ``threading.Thread.join``. ``None``
                (default) waits indefinitely. :meth:`shutdown` is called even if
                the timeout expires.
            close_env (bool, optional): forwarded to :meth:`shutdown`. Defaults to ``True``.
        """
        self._stop = True
        if hasattr(self, "_thread") and self._thread.is_alive():
            self._thread.join(timeout=timeout)
        self.shutdown(close_env=close_env)

    def _postproc(self, tensordict_out):
        if self.split_trajs:
            tensordict_out = split_trajectories(tensordict_out, prefix="collector")
        if self.postproc is not None:
            tensordict_out = self.postproc(tensordict_out)
        if self._exclude_private_keys:

            def is_private(key):
                if isinstance(key, str) and key.startswith("_"):
                    return True
                if isinstance(key, tuple) and any(_key.startswith("_") for _key in key):
                    return True
                return False

            excluded_keys = [
                key for key in tensordict_out.keys(True) if is_private(key)
            ]
            tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True)
        return tensordict_out

    def _update_traj_ids(self, env_output) -> None:
        """Assign fresh trajectory ids where ``_aggregate_end_of_traj`` flags a boundary.

        The mask is computed from ``env_output["next"]``, because the reset keys are
        gone at this point. Flagged entries of the shuttle's
        ``("collector", "traj_ids")`` are replaced by new ids drawn from the
        trajectory pool. Ids and mask are moved to the storing device when one is
        set; otherwise the mask follows the ids' device.
        """
        flagged = _aggregate_end_of_traj(
            env_output.get("next"), done_keys=self.env.done_keys
        )
        if not flagged.any():
            return

        target_device = self.storing_device
        current_ids = self._shuttle.get(("collector", "traj_ids"))
        if target_device is None:
            if flagged.device != current_ids.device:
                flagged = flagged.to(current_ids.device)
        else:
            current_ids = current_ids.to(target_device)
            flagged = flagged.to(target_device)

        fresh_ids = self._traj_pool.get_traj_and_increment(
            flagged.sum(), device=flagged.device
        )
        self._shuttle.set(
            ("collector", "traj_ids"), current_ids.masked_scatter(flagged, fresh_ids)
        )

    @torch.no_grad()
    def rollout(self) -> TensorDictBase | None:
        """Collect one batch of ``self.frames_per_batch`` steps per environment.

        At each step, the policy runs, or a random action is sampled while fewer
        than ``self.init_random_frames`` frames have been collected. The env is
        then stepped with ``step_and_maybe_reset``. The shuttle is cast to the
        policy, env and storing devices as configured, and trajectory ids are
        refreshed after each step.

        Returns:
            The stacked rollout, with truncation flags set on the last step when
            ``set_truncated`` is enabled. Returns ``None`` instead when steps are
            written one by one to the replay buffer with ``add()`` (a replay buffer
            is used, it is not ignored, and ``extend_buffer`` is ``False``).
            If the interruptor stops collection early and buffers are used, the
            whole pre-allocated buffer is returned. Only the first steps are
            overwritten; the others keep ``("collector", "traj_ids") == -1``.
        """
        if self.reset_at_each_iter:
            self._shuttle.update(self.env.reset())

        if self._use_buffers:
            # Mark every time slot as "not collected"; slots filled during this
            # rollout receive real trajectory ids.
            self._final_rollout.fill_(("collector", "traj_ids"), -1)
        tensordicts = []
        with set_exploration_type(self.exploration_type):
            for t in range(self.frames_per_batch):
                if (
                    self.init_random_frames is not None
                    and self._frames < self.init_random_frames
                ):
                    self.env.rand_action(self._shuttle)
                    if (
                        self.policy_device is not None
                        and self.policy_device != self.env_device
                    ):
                        # Move only the policy-output entries to the policy device,
                        # as if the policy had produced them.
                        # TODO: This may break with exclusive / ragged lazy stacks
                        self._shuttle.apply(
                            lambda name, val: val.to(
                                device=self.policy_device, non_blocking=True
                            )
                            if name in self._policy_output_keys
                            else val,
                            out=self._shuttle,
                            named=True,
                            nested_keys=True,
                        )
                else:
                    if self._cast_to_policy_device:
                        if self.policy_device is not None:
                            # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
                            non_blocking = (
                                not self.no_cuda_sync
                                or self.policy_device.type == "cuda"
                            )
                            policy_input = self._shuttle.to(
                                self.policy_device,
                                non_blocking=non_blocking,
                            )
                            if not self.no_cuda_sync:
                                self._sync_policy()
                        elif self.policy_device is None:
                            # we know the tensordict has a device otherwise we would not be here
                            # we can pass this, clear_device_ must have been called earlier
                            policy_input = self._shuttle
                    else:
                        policy_input = self._shuttle
                    # we still do the assignment for security
                    if self.compiled_policy:
                        cudagraph_mark_step_begin()
                    policy_output = self._wrapped_policy(policy_input)
                    if self.compiled_policy:
                        # Detach from cudagraph-owned memory before the next replay.
                        policy_output = policy_output.clone()
                    if self._shuttle is not policy_output:
                        # ad-hoc update shuttle
                        self._shuttle.update(
                            policy_output, keys_to_update=self._policy_output_keys
                        )

                if self._cast_to_env_device:
                    if self.env_device is not None:
                        non_blocking = (
                            not self.no_cuda_sync or self.env_device.type == "cuda"
                        )
                        env_input = self._shuttle.to(
                            self.env_device, non_blocking=non_blocking
                        )
                        if not self.no_cuda_sync:
                            self._sync_env()
                    elif self.env_device is None:
                        # we know the tensordict has a device otherwise we would not be here
                        # we can pass this, clear_device_ must have been called earlier
                        env_input = self._shuttle
                else:
                    env_input = self._shuttle
                env_output, env_next_output = self.env.step_and_maybe_reset(env_input)

                if self._shuttle is not env_output:
                    # ad-hoc update shuttle
                    next_data = env_output.get("next")
                    if self._shuttle_has_no_device:
                        # Keep the shuttle device-less.
                        next_data.clear_device_()
                    self._shuttle.set("next", next_data)

                if self.verbose:
                    torchrl_logger.info(
                        f"Collector: Rollout step completed {self._iter=}."
                    )
                if (
                    self.replay_buffer is not None
                    and not self._ignore_rb
                    and not self.extend_buffer
                ):
                    # Step-wise writing: frames are counted here, nothing is returned.
                    if self.verbose:
                        torchrl_logger.info(
                            f"Collector: Adding {env_output.numel()} frames to replay buffer using add()."
                        )
                    self.replay_buffer.add(self._shuttle)
                    if self._increment_frames(self._shuttle.numel()):
                        return
                else:
                    if self.storing_device is not None:
                        if self.verbose:
                            torchrl_logger.info(
                                f"Collector: Moving to {self.storing_device} and adding to queue."
                            )
                        non_blocking = (
                            not self.no_cuda_sync or self.storing_device.type == "cuda"
                        )
                        tensordicts.append(
                            self._shuttle.to(
                                self.storing_device, non_blocking=non_blocking
                            )
                        )
                        if not self.no_cuda_sync:
                            self._sync_storage()
                    else:
                        if self.verbose:
                            torchrl_logger.info(
                                "Collector: Adding to queue (no device)."
                            )
                        tensordicts.append(self._shuttle)

                # carry over collector data without messing up devices
                collector_data = self._shuttle.get("collector").copy()
                self._shuttle = env_next_output
                if self._shuttle_has_no_device:
                    self._shuttle.clear_device_()
                self._shuttle.set("collector", collector_data)
                self._update_traj_ids(env_output)

                if (
                    self.interruptor is not None
                    and self.interruptor.collection_stopped()
                ):
                    if self.verbose:
                        torchrl_logger.info("Collector: Interruptor stopped.")
                    if (
                        self.replay_buffer is not None
                        and not self._ignore_rb
                        and not self.extend_buffer
                    ):
                        return
                    result = self._final_rollout
                    if self._use_buffers:
                        # Fill only the first t + 1 time steps of the buffer.
                        try:
                            torch.stack(
                                tensordicts,
                                self._final_rollout.ndim - 1,
                                out=self._final_rollout[..., : t + 1],
                            )
                        except RuntimeError:
                            with self._final_rollout.unlock_():
                                torch.stack(
                                    tensordicts,
                                    self._final_rollout.ndim - 1,
                                    out=self._final_rollout[..., : t + 1],
                                )
                    else:
                        result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                    break
            else:
                if self._use_buffers:
                    # Logged only in verbose mode, like every other message here.
                    if self.verbose:
                        torchrl_logger.info("Returning final rollout within buffer.")
                    result = self._final_rollout
                    try:
                        result = torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout,
                        )

                    except RuntimeError:
                        # The buffer may be locked; retry with writes allowed.
                        with self._final_rollout.unlock_():
                            result = torch.stack(
                                tensordicts,
                                self._final_rollout.ndim - 1,
                                out=self._final_rollout,
                            )
                elif (
                    self.replay_buffer is not None
                    and not self._ignore_rb
                    and not self.extend_buffer
                ):
                    return
                else:
                    if self.verbose:
                        torchrl_logger.info(
                            "Returning final rollout with NO buffer (maybe_dense_stack)."
                        )
                    result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                    result.refine_names(..., "time")

        return self._maybe_set_truncated(result)

    def _maybe_set_truncated(self, final_rollout):
        """Flag the final step of ``final_rollout`` as truncated for every truncated key.

        For each key in ``self._truncated_keys``, the ``("next", key)`` entry is set
        to ``True`` at the last index of the trailing (time) dimension. The sibling
        ``"done"`` entry (same prefix, last element replaced by ``"done"``) is then
        OR-ed with the truncated tensor.

        Args:
            final_rollout: the stacked rollout; modified in place.

        Returns:
            The same ``final_rollout`` object.
        """
        # Select every leading dim entirely and only the last element of the last dim.
        leading = final_rollout.ndim - 1
        final_index = (*([slice(None)] * leading), -1)
        for key in self._truncated_keys:
            trunc = final_rollout["next", key]
            trunc[final_index] = True
            final_rollout["next", key] = trunc
            done_key = _replace_last(key, "done")
            final_rollout["next", done_key] = final_rollout["next", done_key] | trunc
        return final_rollout

    @torch.no_grad()
    def reset(self, index=None, **kwargs) -> None:
        """Resets the environments to a new initial state.

        Args:
            index (optional): if provided, only the entries selected by ``index``
                in a boolean mask shaped like each done-group's first done entry
                are flagged for reset (partial reset). This requires a batched
                environment. If ``None`` (default), the shuttle is zeroed and
                every environment is reset.
            **kwargs: forwarded to ``self.env.reset``.

        Raises:
            RuntimeError: if ``index`` is provided while the environment is not
                batched (empty ``batch_size``) or has an empty batch.
        """
        # Keep a copy of the collector metadata (e.g. "traj_ids") so it can be
        # written back after the reset output is merged into the shuttle.
        collector_metadata = self._shuttle.get("collector").clone()
        if index is not None:
            batch_size = self.env.batch_size
            # ``prod`` of an empty batch-size is 1, so an unbatched env has to be
            # detected through its number of batch dims, not the product.
            if len(batch_size) == 0 or prod(batch_size) == 0:
                raise RuntimeError("resetting unique env with index is not permitted.")
            # One boolean "_reset" mask per done-group, True only at ``index``.
            for reset_key, done_keys in zip(
                self.env.reset_keys, self.env.done_keys_groups
            ):
                _reset = torch.zeros(
                    self.env.full_done_spec[done_keys[0]].shape,
                    dtype=torch.bool,
                    device=self.env.device,
                )
                _reset[index] = 1
                self._shuttle.set(reset_key, _reset)
            # The masks only have an effect if the env receives them: pass the
            # shuttle so that only the selected sub-envs are reset.
            reset_output = self.env.reset(self._shuttle, **kwargs)
        else:
            self._shuttle.zero_()
            reset_output = self.env.reset(**kwargs)

        self._shuttle.update(reset_output, inplace=True)
        # Re-base trajectory ids so that the smallest one is 0.
        collector_metadata["traj_ids"] = (
            collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min()
        )
        self._shuttle["collector"] = collector_metadata

    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Shuts down all workers and/or closes the local environment.

        Args:
            timeout (float, optional): The timeout for closing pipes between workers.
                No effect for this class.
            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
            raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
        """
        try:
            if not self.closed:
                self.closed = True
                del self._shuttle
                if self._use_buffers:
                    del self._final_rollout
                if close_env and not self.env.is_closed:
                    self.env.close(raise_if_closed=raise_on_error)
                del self.env
            return
        except Exception as e:
            if raise_on_error:
                raise e
            else:
                pass

    def __del__(self):
        try:
            self.shutdown()
        except Exception:
            # an AttributeError will typically be raised if the collector is deleted when the program ends.
            # In the future, insignificant changes to the close method may change the error type.
            # We excplicitely assume that any error raised during closure in
            # __del__ will not affect the program.
            pass

    def state_dict(self) -> OrderedDict:
        """Returns the local state_dict of the data collector (environment and policy).

        The environment entry is the transform's state dict for a ``TransformedEnv``,
        the env's own state dict for a ``BatchedEnvBase``, and an empty
        ``OrderedDict`` otherwise. The policy entry is only present when the
        collector has a ``_policy_w_state_dict``.

        Returns:
            an ordered dictionary with fields :obj:`"policy_state_dict"` (optional),
            `"env_state_dict"`, `"frames"` and `"iter"`.

        """
        from torchrl.envs.batched_envs import BatchedEnvBase

        env = self.env
        if isinstance(env, TransformedEnv):
            env_sd = env.transform.state_dict()
        elif isinstance(env, BatchedEnvBase):
            env_sd = env.state_dict()
        else:
            env_sd = OrderedDict()

        out = OrderedDict()
        if hasattr(self, "_policy_w_state_dict"):
            out["policy_state_dict"] = self._policy_w_state_dict.state_dict()
        out["env_state_dict"] = env_sd
        # Progress counters, restored by ``load_state_dict``.
        out["frames"] = self._frames
        out["iter"] = self._iter
        return out

    def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
        """Loads a state_dict on the environment and policy.

        Args:
            state_dict (OrderedDict): ordered dictionary containing the fields
                `"env_state_dict"`, `"frames"`, `"iter"` and, when the policy
                has a state, :obj:`"policy_state_dict"` (the layout produced by
                :meth:`state_dict`).
            **kwargs: forwarded to the underlying ``load_state_dict`` calls.
                ``strict`` (default ``True``) additionally makes a missing
                ``"env_state_dict"`` entry, or a missing ``"policy_state_dict"``
                entry for a policy that has a state, an error.

        Raises:
            ValueError: if ``state_dict`` carries a ``"policy_state_dict"`` entry
                but the collector's policy has no state dict to load it into.
            KeyError: if ``strict`` is ``True`` and a required entry is missing.

        """
        strict = kwargs.get("strict", True)
        has_policy_state = hasattr(self, "_policy_w_state_dict")
        if strict or "env_state_dict" in state_dict:
            self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
        # ``state_dict()`` omits "policy_state_dict" when there is no policy state,
        # so only require the entry if it is present or there is somewhere to load it.
        if "policy_state_dict" in state_dict or (strict and has_policy_state):
            if not has_policy_state:
                raise ValueError(
                    "Underlying policy does not have state_dict to load policy_state_dict into."
                )
            self._policy_w_state_dict.load_state_dict(
                state_dict["policy_state_dict"], **kwargs
            )
        self._frames = state_dict["frames"]
        self._iter = state_dict["iter"]

    def __repr__(self) -> str:
        try:
            env_str = indent(f"env={self.env}", 4 * " ")
            policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ")
            td_out_str = repr(getattr(self, "_final_rollout", None))
            if len(td_out_str) > 50:
                td_out_str = td_out_str[:50] + "..."
            td_out_str = indent(f"td_out={td_out_str}", 4 * " ")
            string = (
                f"{self.__class__.__name__}("
                f"\n{env_str},"
                f"\n{policy_str},"
                f"\n{td_out_str},"
                f"\nexploration={self.exploration_type})"
            )
            return string
        except Exception:
            return f"{type(self).__name__}(not_init)"

    def increment_version(self):
        """Increment the policy version held by ``self.policy_version_tracker``.

        Does nothing when no tracker is set.

        Raises:
            RuntimeError: if the tracker has no ``increment_version`` method.
        """
        tracker = self.policy_version_tracker
        if tracker is None:
            return
        if not hasattr(tracker, "increment_version"):
            raise RuntimeError(
                "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
            )
        tracker.increment_version()

    @property
    def policy_version(self) -> str | int | None:
        """The current policy version."""
        if not hasattr(self.policy_version_tracker, "version"):
            return None
        return self.policy_version_tracker.version

    def get_policy_version(self) -> str | int | None:
        """Get the current policy version.

        This method exists to support remote calls in Ray actors, since properties
        cannot be accessed directly through Ray's RPC mechanism.

        Returns:
            The current version number (int) or UUID (str), or None if version tracking is disabled.
        """
        return self.policy_version

    def getattr_policy(self, attr):
        """Return attribute ``attr`` of the wrapped policy (``self._wrapped_policy``)."""
        wrapped = self._wrapped_policy
        return getattr(wrapped, attr)

    def getattr_env(self, attr):
        """Return attribute ``attr`` of the collector's environment (``self.env``)."""
        environment = self.env
        return getattr(environment, attr)

    def getattr_rb(self, attr):
        """Return attribute ``attr`` of the collector's replay buffer (``self.replay_buffer``)."""
        buffer = self.replay_buffer
        return getattr(buffer, attr)

    def get_model(self, model_id: str):
        """Get model instance by ID (for weight sync schemes).

        ``"policy"`` resolves to ``self.policy``, the unwrapped policy. Its
        parameter structure matches what is extracted in the main process, which
        avoids key mismatches when the policy is auto-wrapped (e.g.
        WrappablePolicy -> TensorDictModule). Any other identifier is looked up
        as an attribute of the collector.

        Args:
            model_id: Model identifier (e.g., "policy", "value_net")

        Returns:
            The model instance

        Raises:
            ValueError: If model_id is not recognized, or if ``"policy"`` is
                requested but no policy is set.
        """
        if model_id != "policy":
            if not hasattr(self, model_id):
                raise ValueError(f"Unknown model_id: {model_id}")
            return getattr(self, model_id)
        policy = getattr(self, "policy", None)
        if policy is None:
            raise ValueError(f"No policy found for model_id '{model_id}'")
        return policy


class _MultiDataCollector(DataCollectorBase):
    """Runs a given number of DataCollectors on separate processes.

    Args:
        create_env_fn (List[Callable]): list of Callables, each returning an
            instance of :class:`~torchrl.envs.EnvBase`.
        policy (Callable): Policy to be executed in the environment.
            Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
            If ``None`` is provided (default), the policy used will be a
            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
            ``action_spec``.
            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
            This is the recommended usage of the collector.
            Other callables are accepted too:
            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
            instances) it will be wrapped in a `nn.Module` first.
            Then, the collector will try to assess if these
            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

            - If the policy forward signature matches any of ``forward(self, tensordict)``,
              ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
              any typing with a single argument typed as a subclass of ``TensorDictBase``)
              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

            - In all other cases, an attempt will be made to wrap it as follows:
              ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
                pickled directly), the ``policy_factory`` should be used instead.

    Keyword Args:
        policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
            (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.

            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

            .. warning:: `policy_factory` is currently not compatible with multiprocessed data
                collectors.

        num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored.
            Defaults to `None` (workers determined by the `create_env_fn` length).
        frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
            total number of elements in a batch. If a sequence is provided, represents the number of elements in a
            batch per worker. Total number of elements in a batch is then the sum over the sequence.
        total_frames (int, optional): A keyword-only argument representing the
            total number of frames returned by the collector
            during its lifespan. If the ``total_frames`` is not divisible by
            ``frames_per_batch``, an exception is raised.
            Endless collectors can be created by passing ``total_frames=-1``.
            Defaults to ``-1`` (never ending collector).
        device (int, str or torch.device, optional): The generic device of the
            collector. The ``device`` args fills any non-specified device: if
            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
            ``env_device`` is not specified, its value will be set to ``device``.
            Defaults to ``None`` (No default device).
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        storing_device (int, str or torch.device, optional): The device on which
            the output :class:`~tensordict.TensorDict` will be stored.
            If ``device`` is passed and ``storing_device`` is ``None``, it will
            default to the value indicated by ``device``.
            For long trajectories, it may be necessary to store the data on a different
            device than the one where the policy and env are executed.
            Defaults to ``None`` (the output tensordict isn't on a specific device,
            leaf tensors sit on the device where they were created).
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        env_device (int, str or torch.device, optional): The device on which
            the environment should be cast (or executed if that functionality is
            supported). If not specified and the env has a non-``None`` device,
            ``env_device`` will default to that value. If ``device`` is passed
            and ``env_device=None``, it will default to ``device``. If the value
            as such specified of ``env_device`` differs from ``policy_device``
            and one of them is not ``None``, the data will be cast to ``env_device``
            before being passed to the env (i.e., passing different devices to
            policy and env is supported). Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        policy_device (int, str or torch.device, optional): The device on which
            the policy should be cast.
            If ``device`` is passed and ``policy_device=None``, it will default
            to ``device``. If the value as such specified of ``policy_device``
            differs from ``env_device`` and one of them is not ``None``,
            the data will be cast to ``policy_device`` before being passed to
            the policy (i.e., passing different devices to policy and env is
            supported). Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        create_env_kwargs (dict, optional): A dictionary with the
            keyword arguments used to create an environment. If a list is
            provided, each of its elements will be assigned to a sub-collector.
        collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be
            :class:`~torchrl.collectors.SyncDataCollector`,
            :class:`~torchrl.collectors.MultiSyncDataCollector`,
            :class:`~torchrl.collectors.MultiaSyncDataCollector`
            or a derived class of these.
            Defaults to :class:`~torchrl.collectors.SyncDataCollector`.
        max_frames_per_traj (int, optional): Maximum steps per trajectory.
            Note that a trajectory can span across multiple batches (unless
            ``reset_at_each_iter`` is set to ``True``, see below).
            Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
            If the environment wraps multiple environments together, the number
            of steps is tracked for each environment independently. Negative
            values are allowed, in which case this argument is ignored.
            Defaults to ``None`` (i.e. no maximum number of steps).
        init_random_frames (int, optional): Number of frames for which the
            policy is ignored before it is called. This feature is mainly
            intended to be used in offline/model-based settings, where a
            batch of random trajectories can be used to initialize training.
            If provided, it will be rounded up to the closest multiple of frames_per_batch.
            Defaults to ``None`` (i.e. no random frames).
        reset_at_each_iter (bool, optional): Whether environments should be reset
            at the beginning of a batch collection.
            Defaults to ``False``.
        postproc (Callable, optional): A post-processing transform, such as
            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
            instance.
            Defaults to ``None``.
        split_trajs (bool, optional): Boolean indicating whether the resulting
            TensorDict should be split according to the trajectories.
            See :func:`~torchrl.collectors.utils.split_trajectories` for more
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
            or ``torchrl.envs.utils.ExplorationType.MEAN``.
        reset_when_done (bool, optional): if ``True`` (default), an environment
            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
            entry will be reset at the corresponding indices.
        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
            will be called before (sync) or after (async) each data collection.
            Defaults to ``False``.
        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
            that will be allowed to finish collecting their rollout before the rest are forced to end early.
        num_threads (int, optional): number of threads for this process.
            Defaults to the number of workers.
        num_sub_threads (int, optional): number of threads of the subprocesses.
            Should be equal to one plus the number of processes launched within
            each subprocess (or one if a single process is launched).
            Defaults to 1 for safety: if none is indicated, launching multiple
            workers may charge the cpu load too much and harm performance.
        cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively).
            If ``"stack"``, the data collected from the workers will be stacked along the
            first dimension. This is the preferred behavior as it is the most compatible
            with the rest of the library.
            If ``0``, results will be concatenated along the first dimension
            of the outputs, which can be the batched dimension if the environments are
            batched or the time dimension if not.
            A ``cat_results`` value of ``-1`` will always concatenate results along the
            time dimension. This should be preferred over the default. Intermediate values
            are also accepted.
            Defaults to ``"stack"``.

            .. note:: From v0.5, this argument will default to ``"stack"`` for a better
                interoperability with the rest of the library.

        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
            Truncated keys can be set through ``env.add_truncated_keys``.
            Defaults to ``False``.
        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
            This isn't compatible with environments with dynamic specs. Defaults to ``True``
            for envs without dynamic specs, ``False`` for others.
        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
            but populate the buffer instead. Defaults to ``None``.
        extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
            with single steps. Defaults to `True` for multiprocessed data collectors.
        local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize
            the replay buffer in the main process (legacy behavior). If ``True``, the storage-level
            coordination will handle initialization with real data from worker processes.
            Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning.
            This parameter is deprecated and will be removed in v0.12.
        trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be assumed to be
            compatible with the collector. This defaults to ``True`` for CudaGraphModules
            and ``False`` otherwise.
        compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
            using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
            will be used to compile the policy.
        cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
            in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
            If a dictionary of kwargs is passed, it will be used to wrap the policy.
        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
            or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_) cuda synchronization may cause unexpected
            crashes.
            Defaults to ``False``.
        weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
            or its subclass, responsible for updating the policy weights on remote inference workers.
            If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default,
            which handles weight synchronization across multiple processes.
            Consider using a constructor if the updater needs to be serialized.
        weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models.
            If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default.
        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
            the policy version.
            Defaults to `False`.

    """

    def __init__(
        self,
        create_env_fn: Sequence[Callable[[], EnvBase]],
        policy: None
        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
        *,
        num_workers: int | None = None,
        policy_factory: Callable[[], Callable]
        | list[Callable[[], Callable]]
        | None = None,
        frames_per_batch: int | Sequence[int],
        total_frames: int | None = -1,
        device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        create_env_kwargs: Sequence[dict] | None = None,
        collector_class: type | Callable[[], DataCollectorBase] | None = None,
        max_frames_per_traj: int | None = None,
        init_random_frames: int | None = None,
        reset_at_each_iter: bool = False,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        split_trajs: bool | None = None,
        exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
        reset_when_done: bool = True,
        update_at_each_batch: bool = False,
        preemptive_threshold: float | None = None,
        num_threads: int | None = None,
        num_sub_threads: int = 1,
        cat_results: str | int | None = None,
        set_truncated: bool = False,
        use_buffers: bool | None = None,
        replay_buffer: ReplayBuffer | None = None,
        extend_buffer: bool = True,
        replay_buffer_chunk: bool | None = None,
        local_init_rb: bool | None = None,
        trust_policy: bool | None = None,
        compile_policy: bool | dict[str, Any] | None = None,
        cudagraph_policy: bool | dict[str, Any] | None = None,
        no_cuda_sync: bool = False,
        weight_updater: WeightUpdaterBase
        | Callable[[], WeightUpdaterBase]
        | None = None,
        weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
        track_policy_version: bool | PolicyVersion = False,
    ):
        # Considered closed until setup completes (presumably flipped once the
        # workers are started in ``_run_processes`` -- TODO confirm).
        self.closed = True

        # Set up workers and environment functions
        create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns(
            create_env_fn, num_workers, frames_per_batch
        )

        # Set up basic configuration
        self.set_truncated = set_truncated
        self.num_sub_threads = num_sub_threads
        self.num_threads = num_threads
        self.create_env_fn = create_env_fn
        self._read_compile_kwargs(compile_policy, cudagraph_policy)

        # Set up environment kwargs (one entry per worker)
        self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs)

        # Set up devices (resolved per worker)
        storing_devices, policy_devices, env_devices = self._get_devices(
            storing_device=storing_device,
            env_device=env_device,
            policy_device=policy_device,
            device=device,
        )
        self.storing_device = storing_devices
        self.policy_device = policy_devices
        self.env_device = env_devices
        self.collector_class = collector_class
        # Drop the raw arguments so only the resolved per-worker values are used below.
        del storing_device, env_device, policy_device, device
        self.no_cuda_sync = no_cuda_sync

        # Set up replay buffer
        self._use_buffers = use_buffers
        self.replay_buffer = replay_buffer
        self._setup_multi_replay_buffer(
            local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer
        )

        # Set up policy and weights; CudaGraphModule policies are trusted by default.
        if trust_policy is None:
            trust_policy = policy is not None and isinstance(policy, CudaGraphModule)
        self.trust_policy = trust_policy

        policy_factory = self._setup_policy_factory(policy_factory)

        # Set up weight synchronization: without any factory, scheme or updater,
        # default to shared-memory syncing of the "policy" model.
        if (
            not any(policy_factory)
            and not weight_sync_schemes
            and weight_updater is None
        ):
            weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()}

        self._setup_multi_policy_and_weights(
            policy, policy_factory, weight_updater, weight_sync_schemes
        )

        self._setup_multi_weight_sync(weight_updater, weight_sync_schemes)

        # Set up policy version tracking
        self._setup_multi_policy_version_tracking(track_policy_version)

        # Store policy and policy_factory
        self.policy = policy
        self.policy_factory = policy_factory

        # Set up fallback policy for weight extraction
        self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes)

        # Set up total frames and other parameters
        self._setup_multi_total_frames(
            total_frames, total_frames_per_batch, frames_per_batch
        )
        self.reset_at_each_iter = reset_at_each_iter
        self.postprocs = postproc
        # 0 is used as "no per-trajectory limit".
        self.max_frames_per_traj = (
            int(max_frames_per_traj) if max_frames_per_traj is not None else 0
        )

        # Set up split trajectories
        self.requested_frames_per_batch = total_frames_per_batch
        self.reset_when_done = reset_when_done
        self._setup_split_trajs(split_trajs, reset_when_done)

        # Set up other parameters
        self.init_random_frames = (
            int(init_random_frames) if init_random_frames is not None else 0
        )
        self.update_at_each_batch = update_at_each_batch
        self.exploration_type = exploration_type
        self.frames_per_worker = np.inf

        # Set up preemptive threshold
        self._setup_preemptive_threshold(preemptive_threshold)

        # Run worker processes; tear down whatever was started if this fails.
        try:
            self._run_processes()
        except Exception:
            self.shutdown(raise_on_error=False)
            raise

        # Set up frame tracking and other options
        self._exclude_private_keys = True
        self._frames = 0
        self._iter = -1

        # Validate cat_results. Worker processes are already running here, so
        # shut them down before propagating a validation error.
        try:
            self._validate_cat_results(cat_results)
        except Exception:
            self.shutdown(raise_on_error=False)
            raise

    def _setup_workers_and_env_fns(
        self,
        create_env_fn: Sequence[Callable] | Callable,
        num_workers: int | None,
        frames_per_batch: int | Sequence[int],
    ) -> tuple[list[Callable], int]:
        """Resolve the number of workers and one environment constructor per worker.

        Sets ``self.num_workers`` and ``self._frames_per_batch``.

        Args:
            create_env_fn: either a sequence holding one environment constructor
                per worker (its length defines ``self.num_workers``) or a single
                constructor that is replicated ``num_workers`` times.
            num_workers: number of workers; only used when ``create_env_fn`` is a
                single constructor.
            frames_per_batch: either a single frames-per-batch value or a
                sequence with exactly one value per worker.

        Returns:
            ``(create_env_fn, total_frames_per_batch)``: the per-worker sequence of
            constructors and the total number of frames per batch (the sum of the
            per-worker values when a sequence was given).

        Raises:
            TypeError: if ``create_env_fn`` is a single constructor and
                ``num_workers`` is ``None``.
            ValueError: if ``frames_per_batch`` is a sequence whose length differs
                from the number of workers.
        """
        if isinstance(create_env_fn, Sequence):
            self.num_workers = len(create_env_fn)
        else:
            self.num_workers = num_workers
            if num_workers is None:
                # Without this check ``[fn] * None`` fails with an opaque TypeError.
                raise TypeError(
                    "num_workers must be provided when create_env_fn is a single "
                    "callable rather than a sequence of callables."
                )
            create_env_fn = [create_env_fn] * self.num_workers

        if (
            isinstance(frames_per_batch, Sequence)
            and len(frames_per_batch) != self.num_workers
        ):
            raise ValueError(
                "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker. "
                f"Got {len(frames_per_batch)} values for {self.num_workers} workers."
            )

        self._frames_per_batch = frames_per_batch
        total_frames_per_batch = (
            sum(frames_per_batch)
            if isinstance(frames_per_batch, Sequence)
            else frames_per_batch
        )

        return create_env_fn, total_frames_per_batch

    def _setup_env_kwargs(
        self, create_env_kwargs: Sequence[dict] | dict | None
    ) -> list[dict]:
        """Normalize environment kwargs into one kwargs dict per worker.

        Args:
            create_env_kwargs: ``None`` (no kwargs), a single mapping applied to
                every worker, or a list/tuple with exactly one mapping per worker.
                Values of any other type are returned unchanged.

        Returns:
            A list with ``self.num_workers`` entries. Broadcast entries are
            independent (shallow) copies, so mutating one worker's kwargs never
            leaks into another worker's.

        Raises:
            ValueError: if a list/tuple does not contain exactly
                ``self.num_workers`` entries.
        """
        if isinstance(create_env_kwargs, Mapping):
            # ``[kwargs] * n`` would alias the very same dict n times.
            return [dict(create_env_kwargs) for _ in range(self.num_workers)]
        if create_env_kwargs is None:
            # Likewise, ``[{}] * n`` would share a single empty dict.
            return [{} for _ in range(self.num_workers)]
        if isinstance(create_env_kwargs, (tuple, list)):
            create_env_kwargs = list(create_env_kwargs)
            if len(create_env_kwargs) != self.num_workers:
                raise ValueError(
                    f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}"
                )
        return create_env_kwargs

    def _setup_multi_replay_buffer(
        self,
        local_init_rb: bool | None,
        replay_buffer: ReplayBuffer | None,
        replay_buffer_chunk: bool | None,
        extend_buffer: bool | None,
    ) -> None:
        """Set up replay buffer options for the multi-process collector.

        Args:
            local_init_rb: ``True`` lets the buffer storage initialize itself;
                ``False`` pre-initializes it with a fake tensordict (see
                ``_check_replay_buffer_init``). ``None`` means ``False`` and, when a
                replay buffer is given, emits a ``FutureWarning``.
            replay_buffer: the buffer the workers write to, or ``None``. A buffer
                exposing ``shared=False`` is shared in place.
            replay_buffer_chunk: deprecated alias of ``extend_buffer``. When
                ``extend_buffer`` is ``None`` its value is used instead.
            extend_buffer: stored as ``self.extend_buffer``.

        Raises:
            ValueError: if ``replay_buffer_chunk`` and ``extend_buffer`` are both
                given with different values.
        """
        # Handle local_init_rb deprecation
        if local_init_rb is None:
            local_init_rb = False
            if replay_buffer is not None:
                warnings.warn(
                    "local_init_rb=False is deprecated and will be removed in v0.12. "
                    "The new storage-level initialization provides better performance.",
                    FutureWarning,
                )
        self.local_init_rb = local_init_rb

        self._check_replay_buffer_init()

        if replay_buffer_chunk is not None:
            if extend_buffer is None:
                # Honour the deprecated alias: its value becomes ``extend_buffer``
                # (previously the assignment was inverted and the value was lost).
                extend_buffer = replay_buffer_chunk
                warnings.warn(
                    "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.",
                    DeprecationWarning,
                )
            elif extend_buffer != replay_buffer_chunk:
                raise ValueError(
                    "conflicting values for replay_buffer_chunk and extend_buffer."
                )
        self.extend_buffer = extend_buffer

        if (
            replay_buffer is not None
            and hasattr(replay_buffer, "shared")
            and not replay_buffer.shared
        ):
            torchrl_logger.warning("Replay buffer is not shared. Sharing it.")
            replay_buffer.share()

    def _setup_policy_factory(
        self, policy_factory: Callable | list[Callable] | None
    ) -> list[Callable | None]:
        """Broadcast ``policy_factory`` so that there is one entry per worker.

        A sequence is assumed to already be per-worker and is returned as-is
        (same object). A single callable or ``None`` is repeated
        ``self.num_workers`` times.
        """
        if isinstance(policy_factory, Sequence):
            return policy_factory
        per_worker = [policy_factory] * self.num_workers
        return per_worker

    def _setup_multi_policy_and_weights(
        self,
        policy: TensorDictModule | Callable | None,
        policy_factory: list[Callable | None],
        weight_updater: WeightUpdaterBase | Callable | None,
        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
    ) -> None:
        """Set up policy and extract weights for each device.

        When no ``policy_factory`` entry is set, the policy is resolved for every
        worker's ``policy_device`` via ``self._get_policy_and_device``. One weight
        TensorDict per device is recorded in ``self._policy_weights_dict`` (moved
        to shared memory when the device type is ``"cpu"``). The first resolved
        policy is kept in ``self._fallback_policy``. The ``get_weights_fn`` from
        the last iteration is stored in ``self._get_weights_fn``.

        When factories are used and no ``weight_updater`` is given, only a
        warning is emitted.

        Args:
            policy: the policy shared by all workers, or ``None``.
            policy_factory: one entry per worker; any truthy entry means workers
                build their own policy and ``policy`` must be ``None``.
            weight_updater: legacy weight updater, or ``None``.
            weight_sync_schemes: weight-sync schemes keyed by model id, or ``None``.

        Raises:
            TypeError: if both ``policy`` and a ``policy_factory`` are provided.
        """
        self._policy_weights_dict = {}
        self._fallback_policy = None  # Policy to use for weight extraction fallback

        if any(policy_factory) and policy is not None:
            raise TypeError("policy_factory and policy are mutually exclusive")
        elif not any(policy_factory):
            for policy_device, env_maker, env_maker_kwargs in _zip_strict(
                self.policy_device, self.create_env_fn, self.create_env_kwargs
            ):
                policy_new_device, get_weights_fn = self._get_policy_and_device(
                    policy=policy,
                    policy_device=policy_device,
                    env_maker=env_maker,
                    env_maker_kwargs=env_maker_kwargs,
                )
                # If resolution returned an object of a different type, later
                # iterations start from that resolved object instead of the input.
                if type(policy_new_device) is not type(policy):
                    policy = policy_new_device
                weights = (
                    TensorDict.from_module(policy_new_device)
                    if isinstance(policy_new_device, nn.Module)
                    else TensorDict()
                )
                # For multi-process collectors, ensure weights are in shared memory
                if policy_device and policy_device.type == "cpu":
                    weights = weights.share_memory_()
                # Keyed by device: workers on the same device end up with one
                # entry (a later worker overwrites an earlier one).
                self._policy_weights_dict[policy_device] = weights
                # Store the first policy instance for fallback weight extraction
                if self._fallback_policy is None:
                    self._fallback_policy = policy_new_device
            # NOTE(review): ``get_weights_fn`` is only bound inside the loop; with
            # zero workers this line would raise UnboundLocalError.
            self._get_weights_fn = get_weights_fn
            if weight_updater is None:
                # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default
                # NOTE(review): this only rebinds the local ``weight_sync_schemes``.
                # Nothing below reads it and it is not returned, so the default
                # scheme is discarded here. Confirm the caller applies its own
                # default before calling ``_setup_multi_weight_sync``.
                if weight_sync_schemes is None:
                    weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()}
        elif weight_updater is None:
            warnings.warn(
                "weight_updater is None, but policy_factory is provided. This means that the server will "
                "not know how to send the weights to the workers. If the workers can handle their weight synchronization "
                "on their own (via some specialized worker type / constructor) this may well work, but make sure "
                "your weight synchronization strategy is properly set. To suppress this warning, you can use "
                "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). "
                "This will work whenever your inference and training policies are nn.Module instances with similar structures."
            )

    def _setup_multi_weight_sync(
        self,
        weight_updater: WeightUpdaterBase | Callable | None,
        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
    ) -> None:
        """Choose between scheme-based and legacy weight synchronization.

        With ``weight_sync_schemes`` the schemes are stored and
        ``self.weight_updater`` is cleared. The per-model senders are filled in
        later by ``_run_processes``, once the worker pipes exist. Without schemes,
        the legacy ``weight_updater`` is kept. In both cases
        ``self._weight_senders`` starts out empty.
        """
        if weight_sync_schemes is None:
            # Legacy path: rely on the weight updater object.
            self.weight_updater = weight_updater
            self._weight_sync_schemes = None
            self._weight_senders = {}
            return

        # Scheme-based path: the legacy updater is deliberately disabled.
        self._weight_sync_schemes = weight_sync_schemes
        self._weight_senders = {}
        self.weight_updater = None

    def _setup_multi_policy_version_tracking(
        self, track_policy_version: bool | PolicyVersion
    ) -> None:
        """Resolve ``track_policy_version`` into ``self.policy_version_tracker``.

        ``True`` creates a fresh ``PolicyVersion``. Any object exposing
        ``increment_version`` is used as-is. Anything else (e.g. ``False``)
        disables tracking by storing ``None``. If ``PolicyVersion`` is
        unavailable, a truthy request raises ``ImportError``.
        """
        self.policy_version_tracker = track_policy_version
        if PolicyVersion is None:
            if track_policy_version:
                raise ImportError(
                    "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False."
                )
            self.policy_version_tracker = None
            return

        # ``is True`` matches exactly the bool-and-truthy case.
        if track_policy_version is True:
            tracker = PolicyVersion()
        elif hasattr(track_policy_version, "increment_version"):
            tracker = track_policy_version
        else:
            tracker = None
        self.policy_version_tracker = tracker

    def _setup_fallback_policy(
        self,
        policy: TensorDictModule | Callable | None,
        policy_factory: list[Callable | None],
        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
    ) -> None:
        """Create a policy instance for weight extraction when only factories are given.

        When a ``policy`` is passed, ``_setup_multi_policy_and_weights`` already
        stores the fallback policy. This method handles the case where ``policy``
        is ``None``, a ``policy_factory`` entry is set and weight-sync schemes are
        in use. The first non-``None`` factory is then called once, and the result
        is stored in ``self._fallback_policy``. An existing, non-``None`` fallback
        is left untouched.

        Args:
            policy: the shared policy, or ``None``.
            policy_factory: per-worker factories as a list or a tuple (entries
                may be ``None``), or a single factory.
            weight_sync_schemes: weight-sync schemes; nothing happens when ``None``.
        """
        if not (
            policy is None and any(policy_factory) and weight_sync_schemes is not None
        ):
            return
        if getattr(self, "_fallback_policy", None) is not None:
            return
        if isinstance(policy_factory, Sequence):
            # Accept tuples as well as lists, and skip leading ``None`` entries
            # such as in ``[None, factory]``.
            first_factory = next(
                (factory for factory in policy_factory if factory is not None), None
            )
        else:
            first_factory = policy_factory
        if first_factory is not None:
            # Create a policy instance for weight extraction
            # This will be a reference to a policy with the same structure
            # For shared memory, modifications to any policy will be visible here
            self._fallback_policy = first_factory()

    def _setup_multi_total_frames(
        self,
        total_frames: int,
        total_frames_per_batch: int,
        frames_per_batch: int | Sequence[int],
    ) -> None:
        """Normalize ``total_frames`` and store it as ``self.total_frames``.

        ``None`` or a negative value means "no limit" and is stored as
        ``float("inf")``. Infinite values are kept as-is; every other value is
        converted with ``int()``. When ``rl_warnings()`` is enabled and the total
        is not a multiple of ``total_frames_per_batch``, a warning reports how
        many extra frames will be collected. ``frames_per_batch`` is part of the
        signature but is not read here.
        """
        if total_frames is None or total_frames < 0:
            self.total_frames = float("inf")
            return

        overflow = total_frames % total_frames_per_batch
        if overflow != 0 and rl_warnings():
            warnings.warn(
                f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). "
                f"This means {total_frames_per_batch - overflow} additional frames will be collected. "
                "To silence this message, set the environment variable RL_WARNINGS to False."
            )
        if total_frames == float("inf"):
            self.total_frames = total_frames
        else:
            self.total_frames = int(total_frames)

    def _setup_split_trajs(
        self, split_trajs: bool | None, reset_when_done: bool
    ) -> None:
        """Store the trajectory-splitting option on ``self.split_trajs``.

        ``None`` is treated as ``False``. Splitting is rejected when the
        environments are not reset on ``done``.

        Raises:
            RuntimeError: if ``split_trajs`` is truthy while ``reset_when_done``
                is falsy.
        """
        if split_trajs is None:
            self.split_trajs = False
            return
        if split_trajs and not reset_when_done:
            raise RuntimeError(
                "Cannot split trajectories when reset_when_done is False."
            )
        self.split_trajs = split_trajs

    def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None:
        """Configure preemption (early stop of slow workers).

        Without a threshold, ``self.preemptive_threshold`` is ``1.0`` and no
        interruptor is created. Otherwise, the threshold is clipped to
        ``[0, 1]``, and an ``_Interruptor`` proxy is obtained from a freshly
        started ``_InterruptorManager``.

        Raises:
            NotImplementedError: if a threshold is requested on OSX.
        """
        if preemptive_threshold is None:
            self.preemptive_threshold = 1.0
            self.interruptor = None
            return

        if _is_osx:
            raise NotImplementedError(
                "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform."
            )
        self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0)
        interruptor_manager = _InterruptorManager()
        interruptor_manager.start()
        self.interruptor = interruptor_manager._Interruptor()

    def _validate_cat_results(self, cat_results: str | int | None) -> None:
        """Validate ``cat_results`` and store it on ``self.cat_results``.

        Accepted values are ``None``, the string ``"stack"``, or an integer
        concatenation dimension. Values other than ``None`` and ``"stack"`` are
        only accepted on ``MultiSyncDataCollector`` instances.

        Raises:
            ValueError: for any other value, or for an integer dimension on a
                collector that is not a ``MultiSyncDataCollector``.
        """
        if cat_results is not None:
            if isinstance(cat_results, str):
                acceptable = cat_results == "stack"
            else:
                acceptable = isinstance(cat_results, int)
            if not acceptable:
                raise ValueError(
                    "cat_results must be a string ('stack') "
                    f"or an integer representing the cat dimension. Got {cat_results}."
                )

        is_multi_sync = isinstance(self, MultiSyncDataCollector)
        if not is_multi_sync and cat_results not in ("stack", None):
            raise ValueError(
                "cat_results can only be used with ``MultiSyncDataCollector``."
            )
        self.cat_results = cat_results

    def _check_replay_buffer_init(self):
        """Prepare the replay buffer's storage before workers start writing.

        Nothing happens when there is no replay buffer, or when its storage
        reports ``initialized`` (a storage without that attribute counts as
        initialized). A buffer without a ``_storage`` attribute is treated as
        uninitialized. For uninitialized storage:

        * with ``self.local_init_rb`` set, the buffer is only shared. The original
          design note says the storage coordinates its own initialization during
          its first write.
        * otherwise (legacy path), a fake tensordict obtained from the first
          environment constructor, plus a ``("collector", "traj_ids")`` entry, is
          written once with ``extend`` and then cleared with ``empty()``.
        """
        buffer = self.replay_buffer
        if buffer is None:
            return
        already_initialized = hasattr(buffer, "_storage") and getattr(
            buffer._storage, "initialized", True
        )
        if already_initialized:
            return

        if self.local_init_rb:
            # Storage-level initialization: sharing the buffer is all that is
            # required here.
            buffer.share()
            return

        # Legacy path: build a fake tensordict to write once.
        first_env_fn = self.create_env_fn[0]
        if isinstance(first_env_fn, EnvCreator):
            fake_td = first_env_fn.meta_data.tensordict
        elif isinstance(first_env_fn, EnvBase):
            fake_td = first_env_fn.fake_tensordict()
        else:
            # NOTE(review): the environment built here is never closed; confirm
            # this does not leak resources for heavyweight environments.
            fake_td = first_env_fn(**self.create_env_kwargs[0]).fake_tensordict()
        fake_td["collector", "traj_ids"] = torch.zeros(
            fake_td.shape, dtype=torch.long
        )
        # ``extend`` (with an added trailing dim) keeps time-related transforms
        # from failing on this single fake step.
        buffer.extend(fake_td.unsqueeze(-1))
        buffer.empty()

    @classmethod
    def _total_workers_from_env(cls, env_creators):
        """Count the worker processes implied by ``env_creators``.

        Lists and tuples are traversed recursively and their counts summed. A
        ``ParallelEnv`` contributes its ``num_workers``. Anything else counts as
        one worker.
        """
        if not isinstance(env_creators, (tuple, list)):
            # Local import, as in the rest of this module, to avoid import cycles.
            from torchrl.envs import ParallelEnv

            if isinstance(env_creators, ParallelEnv):
                return env_creators.num_workers
            return 1

        total = 0
        for creator in env_creators:
            total += cls._total_workers_from_env(creator)
        return total

    def _get_devices(
        self,
        *,
        storing_device: torch.device,
        policy_device: torch.device,
        env_device: torch.device,
        device: torch.device,
    ):
        """Resolve the storing, policy and environment devices of every worker.

        Each argument is either a single device specification, broadcast to all
        ``self.num_workers`` workers, or a list/tuple with one entry per worker.
        Each worker's four entries are then resolved by
        ``SyncDataCollector._get_devices``.

        Returns:
            ``(storing_device, policy_device, env_device)``: three tuples with one
            resolved entry per worker.

        Raises:
            RuntimeError: if a per-worker list does not have exactly
                ``self.num_workers`` entries.
        """
        num_workers = self.num_workers

        def _per_worker(spec):
            # A list/tuple is already per-worker; anything else is broadcast.
            if isinstance(spec, (list, tuple)):
                return spec
            return [spec] * num_workers

        storing_devices = _per_worker(storing_device)
        policy_devices = _per_worker(policy_device)
        env_devices = _per_worker(env_device)
        devices = _per_worker(device)
        if not (
            len(devices)
            == len(storing_devices)
            == len(policy_devices)
            == len(env_devices)
            == num_workers
        ):
            raise RuntimeError(
                f"The length of the devices does not match the number of workers: {num_workers}."
            )
        # Distinct loop names: the original comprehension shadowed the lists above.
        resolved = [
            SyncDataCollector._get_devices(
                storing_device=worker_storing,
                policy_device=worker_policy,
                env_device=worker_env,
                device=worker_device,
            )
            for worker_storing, worker_policy, worker_env, worker_device in zip(
                storing_devices, policy_devices, env_devices, devices
            )
        ]
        storing_device, policy_device, env_device = zip(*resolved)
        return storing_device, policy_device, env_device

    def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
        """Return how many frames worker ``worker_idx`` collects per batch.

        The base implementation raises; subclasses are expected to override it
        (``_run_processes`` calls it once per worker).
        """
        raise NotImplementedError()

    @property
    def _queue_len(self) -> int:
        """Maximum size of the worker-to-main output queue.

        The base implementation raises; subclasses are expected to override it
        (``_run_processes`` passes it to ``mp.Queue``).
        """
        raise NotImplementedError()

    def _run_processes(self) -> None:
        if self.num_threads is None:
            total_workers = self._total_workers_from_env(self.create_env_fn)
            self.num_threads = max(
                1, torch.get_num_threads() - total_workers
            )  # 1 more thread for this proc

        # Weight senders will be initialized after workers are ready (via init_on_sender)
        torch.set_num_threads(self.num_threads)
        queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
        self.procs = []
        self.pipes = []
        self._traj_pool = _TrajectoryPool(lock=True)
        # Create a policy on the right device
        policy_factory = self.policy_factory
        if any(policy_factory):
            policy_factory = [
                CloudpickleWrapper(_policy_factory)
                for _policy_factory in policy_factory
            ]

        for i, (env_fun, env_fun_kwargs) in enumerate(
            zip(self.create_env_fn, self.create_env_kwargs)
        ):
            pipe_parent, pipe_child = mp.Pipe()  # send messages to procs
            if env_fun.__class__.__name__ != "EnvCreator" and not isinstance(
                env_fun, EnvBase
            ):  # to avoid circular imports
                env_fun = CloudpickleWrapper(env_fun)

            policy_device = self.policy_device[i]
            storing_device = self.storing_device[i]
            env_device = self.env_device[i]
            # We take the weights, the policy, and locally dispatch the weights to the policy
            #  while we send the policy to the remote process.
            #  This makes sure that a given set of shared weights for a given device are
            #  shared for all policies that rely on that device.
            policy = self.policy
            policy_weights = self._policy_weights_dict.get(policy_device)
            if policy is not None and policy_weights is not None:
                cm = policy_weights.to_module(policy)
            else:
                cm = contextlib.nullcontext()
            with cm:
                kwargs = {
                    "policy_factory": policy_factory[i],
                    "pipe_parent": pipe_parent,
                    "pipe_child": pipe_child,
                    "queue_out": queue_out,
                    "create_env_fn": env_fun,
                    "create_env_kwargs": env_fun_kwargs,
                    "policy": policy,
                    "max_frames_per_traj": self.max_frames_per_traj,
                    "frames_per_batch": self.frames_per_batch_worker(worker_idx=i),
                    "reset_at_each_iter": self.reset_at_each_iter,
                    "policy_device": policy_device,
                    "storing_device": storing_device,
                    "env_device": env_device,
                    "exploration_type": self.exploration_type,
                    "reset_when_done": self.reset_when_done,
                    "idx": i,
                    "interruptor": self.interruptor,
                    "set_truncated": self.set_truncated,
                    "use_buffers": self._use_buffers,
                    "replay_buffer": self.replay_buffer,
                    "extend_buffer": self.extend_buffer,
                    "traj_pool": self._traj_pool,
                    "trust_policy": self.trust_policy,
                    "compile_policy": self.compiled_policy_kwargs
                    if self.compiled_policy
                    else False,
                    "cudagraph_policy": self.cudagraphed_policy_kwargs
                    if self.cudagraphed_policy
                    else False,
                    "no_cuda_sync": self.no_cuda_sync,
                    "collector_class": self.collector_class,
                    "postproc": self.postprocs
                    if self.replay_buffer is not None
                    else None,
                    "weight_sync_schemes": self._weight_sync_schemes,
                }
                proc = _ProcessNoWarn(
                    target=_main_async_collector,
                    num_threads=self.num_sub_threads,
                    kwargs=kwargs,
                )
                # proc.daemon can't be set as daemonic processes may be launched by the process itself
                try:
                    proc.start()
                except TypeError as err:
                    if "cannot pickle" in str(err):
                        raise RuntimeError(
                            "A non-serializable object was passed to the collector workers."
                        ) from err
                except RuntimeError as err:
                    if "Cowardly refusing to serialize non-leaf tensor" in str(err):
                        raise RuntimeError(
                            "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. "
                            "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n"
                            "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n"
                            "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead."
                        ) from err
                    else:
                        raise err
                except _pickle.PicklingError as err:
                    if "<lambda>" in str(err):
                        raise RuntimeError(
                            """Can't open a process with doubly cloud-pickled lambda function.
This error is likely due to an attempt to use a ParallelEnv in a
multiprocessed data collector. To do this, consider wrapping your
lambda function in an `torchrl.envs.EnvCreator` wrapper as follows:
`env = ParallelEnv(N, EnvCreator(my_lambda_function))`.
This will not only ensure that your lambda function is cloud-pickled once, but
also that the state dict is synchronised across processes if needed."""
                        ) from err
                pipe_child.close()
                self.procs.append(proc)
                self.pipes.append(pipe_parent)

        # Worker registration now handled by init_on_sender() after workers are ready
        for i, pipe_parent in enumerate(self.pipes):
            pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT)
            try:
                msg = pipe_parent.recv()
            except EOFError as e:
                raise RuntimeError(
                    f"Worker {i} failed to initialize and closed the connection before sending status. "
                    f"This typically indicates that the worker process crashed during initialization. "
                    f"Check the worker process logs for the actual error."
                ) from e
            if msg != "instantiated":
                # Check if it's an error dict from worker
                if isinstance(msg, dict) and msg.get("error"):
                    # Reconstruct the exception from the worker
                    exc_type_name = msg["exception_type"]
                    exc_msg = msg["exception_msg"]
                    traceback_str = msg["traceback"]

                    # Try to get the actual exception class
                    exc_class = None
                    exc_module = msg["exception_module"]

                    if exc_module == "builtins":
                        # Get from builtins
                        import builtins

                        exc_class = getattr(builtins, exc_type_name, None)
                    else:
                        # Try to import from the module
                        try:
                            import importlib

                            mod = importlib.import_module(exc_module)
                            exc_class = getattr(mod, exc_type_name, None)
                        except Exception:
                            pass

                    # Re-raise with original exception type if possible
                    if exc_class is not None:
                        raise exc_class(
                            f"{exc_msg}\n\nWorker traceback:\n{traceback_str}"
                        )
                    else:
                        # Fall back to RuntimeError if we can't get the original type
                        raise RuntimeError(
                            f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}"
                        )
                else:
                    # Legacy string error message
                    raise RuntimeError(msg)

        # Initialize all weight sync schemes now that workers are ready
        # This calls init_on_sender() for each scheme which:
        # 1. Creates transports for all workers
        # 2. Creates and configures the sender
        # 3. For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock
        if self._weight_sync_schemes:
            for model_id, scheme in self._weight_sync_schemes.items():
                # Check if scheme has new API or legacy API
                if hasattr(scheme, "init_on_sender"):
                    scheme.init_on_sender(model_id=model_id, context=self)
                    # Get the initialized sender
                    self._weight_senders[model_id] = scheme.get_sender()
                # else: keep using legacy _weight_senders initialization from before

        self.queue_out = queue_out
        self.closed = False

    # Class-level default for the "workers are running free" flag: ``start()``
    # sets it to True on the instance, and ``pause()`` clears and restores it.
    _running_free = False

    def start(self):
        """Starts the collector(s) for asynchronous data collection.

        The collected data is stored in the provided replay buffer. This method initiates the background collection of
        data across multiple processes, allowing for decoupling of data collection and training.

        Every worker is sent a ``"run_free"`` command, and the collector is flagged
        as running free. That flag is what makes :meth:`pause` usable.

        Raises:
            RuntimeError: If no replay buffer is defined during the collector's initialization.
            RuntimeError: If ``init_random_frames`` is positive (not supported yet).

        Example:
            >>> import time
            >>> from functools import partial
            >>>
            >>> import tqdm
            >>>
            >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy
            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
            >>> from torchrl.envs import GymEnv, set_gym_backend
            >>> import ale_py
            >>>
            >>> # Set the gym backend to gymnasium
            >>> set_gym_backend("gymnasium").set()
            >>>
            >>> if __name__ == "__main__":
            ...     # Create a random policy for the Pong environment
            ...     env_fn = partial(GymEnv, "ALE/Pong-v5")
            ...     policy = RandomPolicy(env_fn().action_spec)
            ...
            ...     # Initialize a shared replay buffer
            ...     rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True)
            ...
            ...     # Create a multi-async data collector with 16 environments
            ...     num_envs = 16
            ...     collector = MultiaSyncDataCollector(
            ...         [env_fn] * num_envs,
            ...         policy=policy,
            ...         replay_buffer=rb,
            ...         frames_per_batch=num_envs * 16,
            ...         total_frames=-1,
            ...     )
            ...
            ...     # Progress bar to track the number of collected frames
            ...     pbar = tqdm.tqdm(total=100_000)
            ...
            ...     # Start the collector asynchronously
            ...     collector.start()
            ...
            ...     # Track the write count of the replay buffer
            ...     prec_wc = 0
            ...     while True:
            ...         wc = rb.write_count
            ...         c = wc - prec_wc
            ...         prec_wc = wc
            ...
            ...         # Update the progress bar
            ...         pbar.update(c)
            ...         pbar.set_description(f"Write Count: {rb.write_count}")
            ...
            ...         # Check the write count every 0.5 seconds
            ...         time.sleep(0.5)
            ...
            ...         # Stop when the desired number of frames is reached
            ...         if rb.write_count > 100_000:
            ...             break
            ...
            ...     # Shut down the collector
            ...     collector.async_shutdown()
        """
        if self.replay_buffer is None:
            raise RuntimeError("Replay buffer must be defined for execution.")
        if self.init_random_frames is not None and self.init_random_frames > 0:
            raise RuntimeError(
                "Cannot currently start() a collector that requires random frames. Please submit a feature request on github."
            )
        self._running_free = True
        for pipe in self.pipes:
            pipe.send((None, "run_free"))

    @contextlib.contextmanager
    def pause(self):
        """Context manager that pauses the collector if it is running free.

        On entry, every worker is sent a ``"pause"`` command, and the collector
        waits for one ``(idx, "paused")`` acknowledgement per worker on
        ``queue_out``. On exit, every worker is sent ``"restart"`` and the
        collector is marked as running free again. This also happens when the
        ``with`` body raises, so an error never leaves the workers paused.

        Raises:
            RuntimeError: if the collector is not running free (see :meth:`start`).
            ValueError: if a worker acknowledges with anything but ``"paused"``.
        """
        if not self._running_free:
            raise RuntimeError("Collector cannot be paused.")
        for pipe in self.pipes:
            pipe.send((None, "pause"))
        # Make sure all workers are paused; each ack carries the worker index.
        for _ in self.pipes:
            idx, msg = self.queue_out.get()
            if msg != "paused":
                raise ValueError(f"Expected paused, but got {msg=}.")
            torchrl_logger.info(f"Worker {idx} is paused.")
        self._running_free = False
        try:
            yield None
        finally:
            # Resume even if the ``with`` body raised; otherwise the workers would
            # stay paused and ``_running_free`` would remain False.
            for pipe in self.pipes:
                pipe.send((None, "restart"))
            self._running_free = True

    def __del__(self):
        # Best-effort cleanup at garbage-collection time. At interpreter exit the
        # collector may be half torn down, so ``shutdown`` can raise (typically an
        # AttributeError). Any such error is deliberately ignored: a failure while
        # finalizing must not affect the program.
        with contextlib.suppress(Exception):
            self.shutdown()

    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Shuts down all processes. This operation is irreversible.

        Args:
            timeout (float, optional): The timeout for closing pipes between workers.
                Forwarded to ``_shutdown_main``.
            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
                Only ``True`` is supported.
            raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
                When ``False``, any ``Exception`` raised during shutdown is ignored.

        Raises:
            RuntimeError: if ``close_env`` is ``False``.
        """
        if not close_env:
            raise RuntimeError(
                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
            )
        try:
            self._shutdown_main(timeout)
        except Exception:
            if raise_on_error:
                # Bare ``raise`` keeps the original traceback intact.
                raise
            # Best-effort shutdown requested: the failure is intentionally ignored.

    def _shutdown_main(self, timeout: float | None = None) -> None:
        """Ask every worker to close, collect acknowledgements, then release resources.

        Args:
            timeout: time budget in seconds for collecting the workers' ``"closed"``
                acknowledgements; defaults to 10. Each poll waits
                ``timeout / 1000 / num_workers``, for at most 1000 rounds.

        A worker counts as closed when it acknowledges, when its process is no
        longer alive, or when its pipe is broken or at EOF. Whatever happens, the
        ``finally`` clause resets the torch thread count and terminates any worker
        process that is still alive.

        Raises:
            RuntimeError: if a worker answers with anything other than ``"closed"``.
        """
        if timeout is None:
            timeout = 10
        try:
            if self.closed:
                return
            # Helper defined elsewhere; presumably raises if a worker died abnormally.
            _check_for_faulty_process(self.procs)
            all_closed = [False] * self.num_workers
            rep = 0
            for idx in range(self.num_workers):
                if not self.procs[idx].is_alive():
                    # Nothing to notify: the process is already gone.
                    all_closed[idx] = True
                    continue
                try:
                    self.pipes[idx].send((None, "close"))
                except BrokenPipeError:
                    # The worker exited between the liveness check and the send.
                    all_closed[idx] = True

            while not all(all_closed) and rep < 1000:
                rep += 1
                for idx in range(self.num_workers):
                    if all_closed[idx]:
                        continue
                    if not self.procs[idx].is_alive():
                        all_closed[idx] = True
                        continue
                    try:
                        if self.pipes[idx].poll(timeout / 1000 / self.num_workers):
                            msg = self.pipes[idx].recv()
                            if msg != "closed":
                                raise RuntimeError(
                                    f"got {msg} but expected 'closed'"
                                )
                            all_closed[idx] = True
                    except (BrokenPipeError, EOFError):
                        # The peer end is gone: treat the worker as closed.
                        all_closed[idx] = True
            self.closed = True

            self.queue_out.close()
            for pipe in self.pipes:
                pipe.close()
            for proc in self.procs:
                proc.join(1.0)
        finally:
            import torchrl

            num_threads = min(
                torchrl._THREAD_POOL_INIT,
                torch.get_num_threads()
                + self._total_workers_from_env(self.create_env_fn),
            )
            torch.set_num_threads(num_threads)

            for proc in self.procs:
                if proc.is_alive():
                    proc.terminate()

    def async_shutdown(self, timeout: float | None = None):
        return self.shutdown(timeout=timeout)

    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        """Seeds every worker in turn, then resets all environments.

        Workers are seeded one after the other: each one receives
        ``(seed, static_seed)`` with a ``"seed"`` command and answers with the
        seed to hand to the next worker, which becomes the new ``seed``. Once
        all workers are seeded, :meth:`reset` is called.

        Args:
            seed: integer representing the seed to be used for the environment.
            static_seed (bool, optional): if ``True``, the seed is not incremented.
                Defaults to False

        Returns:
            The seed returned by the last worker.

        Raises:
            RuntimeError: if a worker does not acknowledge with ``"seeded"``.

        Examples:
            >>> from torchrl.envs import ParallelEnv
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> from tensordict.nn import TensorDictModule
            >>> from torch import nn
            >>> env_fn = lambda: GymEnv("Pendulum-v1")
            >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn)
            >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
            >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300)
            >>> out_seed = collector.set_seed(1)  # out_seed = 6

        """
        _check_for_faulty_process(self.procs)
        for worker in range(self.num_workers):
            channel = self.pipes[worker]
            channel.send(((seed, static_seed), "seed"))
            returned_seed, reply = channel.recv()
            if reply != "seeded":
                raise RuntimeError(f"Expected msg='seeded', got {reply}")
            # Chain the seed: the next worker starts from this worker's output.
            seed = returned_seed
        self.reset()
        return seed

    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
        """Sends a ``"reset"`` command to the selected workers and waits for them.

        All requests are sent first, then every acknowledgement is collected, so
        the workers reset concurrently.

        Args:
            reset_idx: Optional. Per-worker flags; a truthy entry at position
                ``i`` resets worker ``i``. If None, all workers are reset.

        Raises:
            RuntimeError: if a worker does not acknowledge with ``"reset"``.
        """
        _check_for_faulty_process(self.procs)

        if reset_idx is None:
            reset_idx = [True] * self.num_workers
        selected = [w for w in range(self.num_workers) if reset_idx[w]]
        for w in selected:
            self.pipes[w].send((None, "reset"))
        for w in selected:
            _, reply = self.pipes[w].recv()
            if reply != "reset":
                raise RuntimeError(f"Expected msg='reset', got {reply}")

    def state_dict(self) -> OrderedDict:
        """Gathers the state of every worker plus the collector counters.

        Returns:
            An ``OrderedDict`` with one ``"worker{i}"`` entry per worker (the
            payload that worker returned), followed by ``"frames"`` and
            ``"iter"`` holding ``self._frames`` and ``self._iter``.

        Raises:
            RuntimeError: if a worker does not answer with ``"state_dict"``.
        """
        workers = range(self.num_workers)
        for w in workers:
            self.pipes[w].send((None, "state_dict"))
        collected = OrderedDict()
        for w in workers:
            payload, reply = self.pipes[w].recv()
            if reply != "state_dict":
                raise RuntimeError(f"Expected msg='state_dict', got {reply}")
            collected[f"worker{w}"] = payload
        collected["frames"] = self._frames
        collected["iter"] = self._iter
        return collected

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """Pushes each worker's state to it and restores the collector counters.

        Args:
            state_dict (OrderedDict): state_dict of the form
                ``{"worker0": state_dict0, "worker1": state_dict1, "frames": ..., "iter": ...}``.

        Raises:
            RuntimeError: if a worker does not acknowledge with ``"loaded"``;
                the counters are left untouched in that case.
        """
        indices = range(self.num_workers)
        for i in indices:
            worker_state = state_dict[f"worker{i}"]
            self.pipes[i].send((worker_state, "load_state_dict"))
        for i in indices:
            _, reply = self.pipes[i].recv()
            if reply != "loaded":
                raise RuntimeError(f"Expected msg='loaded', got {reply}")
        self._frames = state_dict["frames"]
        self._iter = state_dict["iter"]

    def increment_version(self):
        """Bumps the policy version if a version tracker is attached.

        Does nothing when ``policy_version_tracker`` is ``None``.

        Raises:
            RuntimeError: if the tracker has no ``increment_version`` method.
        """
        tracker = self.policy_version_tracker
        if tracker is None:
            return
        if not hasattr(tracker, "increment_version"):
            raise RuntimeError(
                "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
            )
        tracker.increment_version()

    @property
    def policy_version(self) -> str | int | None:
        """The current policy version."""
        if not hasattr(self.policy_version_tracker, "version"):
            return None
        return self.policy_version_tracker.version

    def get_policy_version(self) -> str | int | None:
        """Get the current policy version.

        This method exists to support remote calls in Ray actors, since properties
        cannot be accessed directly through Ray's RPC mechanism.

        Returns:
            The current version number (int) or UUID (str), or None if version tracking is disabled.
        """
        return self.policy_version

    def getattr_policy(self, attr):
        """Fetches ``attr`` from the policy living in worker 0.

        Args:
            attr (str): name of the policy attribute to read.

        Returns:
            The value reported by worker 0.

        Raises:
            AttributeError: re-raised when the worker reports the attribute is missing.
            RuntimeError: if the worker replies with an unexpected message tag.
        """
        _check_for_faulty_process(self.procs)

        first_worker = self.pipes[0]
        first_worker.send((attr, "getattr_policy"))
        value, reply = first_worker.recv()
        if reply != "getattr_policy":
            raise RuntimeError(f"Expected msg='getattr_policy', got {reply}")
        # The worker ships lookup failures back as exception objects.
        if isinstance(value, AttributeError):
            raise value
        return value

    def getattr_env(self, attr):
        """Fetches ``attr`` from the environment living in worker 0.

        Args:
            attr (str): name of the environment attribute to read.

        Returns:
            The value reported by worker 0.

        Raises:
            AttributeError: re-raised when the worker reports the attribute is missing.
            RuntimeError: if the worker replies with an unexpected message tag.
        """
        _check_for_faulty_process(self.procs)

        channel = self.pipes[0]
        channel.send((attr, "getattr_env"))
        outcome, tag = channel.recv()
        if tag != "getattr_env":
            raise RuntimeError(f"Expected msg='getattr_env', got {tag}")
        # A missing attribute comes back as an AttributeError instance.
        if isinstance(outcome, AttributeError):
            raise outcome
        return outcome

    def getattr_rb(self, attr):
        """Reads ``attr`` from ``self.replay_buffer`` with plain ``getattr``."""
        buffer = self.replay_buffer
        return getattr(buffer, attr)

    def get_model(self, model_id: str):
        """Resolves a model identifier to a model object (used by weight sync schemes).

        ``"policy"`` resolves to ``self._fallback_policy`` when it is set and not
        None, otherwise to ``self.policy``. Any other identifier is looked up as
        an attribute of the collector.

        Args:
            model_id: Model identifier (e.g., "policy", "value_net")

        Returns:
            The model instance

        Raises:
            ValueError: If model_id is not recognized
        """
        if model_id != "policy":
            if not hasattr(self, model_id):
                raise ValueError(f"Unknown model_id: {model_id}")
            return getattr(self, model_id)
        for attr_name in ("_fallback_policy", "policy"):
            candidate = getattr(self, attr_name, None)
            if candidate is not None:
                return candidate
        raise ValueError(f"No policy found for model_id '{model_id}'")

    def get_cached_weights(self, model_id: str):
        """Looks up cached policy weights for the policy device (weight sync schemes).

        Only ``model_id == "policy"`` is supported, and only when the collector
        has a ``_policy_weights_dict`` cache. If ``policy_device`` is a list or
        tuple, its first entry is used (None when it is empty).

        Args:
            model_id: Model identifier

        Returns:
            Cached TensorDict weights or None if not available
        """
        if model_id != "policy" or not hasattr(self, "_policy_weights_dict"):
            return None
        device = self.policy_device
        if isinstance(device, (list, tuple)):
            device = device[0] if device else None
        return self._policy_weights_dict.get(device)


@accept_remote_rref_udf_invocation
class MultiSyncDataCollector(_MultiDataCollector):
    """Runs a given number of DataCollectors on separate processes synchronously.

    .. aafig::

            +----------------------------------------------------------------------+
            |            "MultiSyncDataCollector"                 |                |
            |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|                |
            |   "Collector 1" |  "Collector 2"  |  "Collector 3"  |     Main       |
            |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~|
            | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" |                |
            |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~|
            |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" |                |
            |        |        |        |        |        |        |                |
            |       "actor"   |        |        |       "actor"   |                |
            |                 |        |        |                 |                |
            | "step" | "step" |       "actor"   |                 |                |
            |        |        |                 |                 |                |
            |        |        |                 | "step" | "step" |                |
            |        |        |                 |        |        |                |
            |       "actor"   | "step" | "step" |       "actor"   |                |
            |                 |        |        |                 |                |
            |                 |       "actor"   |                 |                |
            |                 |                 |                 |                |
            |                       "yield batch of traj 1"------->"collect, train"|
            |                                                     |                |
            | "step" | "step" | "step" | "step" | "step" | "step" |                |
            |        |        |        |        |        |        |                |
            |       "actor"   |       "actor"   |        |        |                |
            |                 | "step" | "step" |       "actor"   |                |
            |                 |        |        |                 |                |
            | "step" | "step" |       "actor"   | "step" | "step" |                |
            |        |        |                 |        |        |                |
            |       "actor"   |                 |       "actor"   |                |
            |                       "yield batch of traj 2"------->"collect, train"|
            |                                                     |                |
            +----------------------------------------------------------------------+

    Envs can be identical or different.

    The collection starts when the next item of the collector is queried,
    and no environment step is computed in between the reception of a batch of
    trajectory and the start of the next collection.
    This class can be safely used with online RL sota-implementations.

    .. note::
        Python requires multiprocessed code to be instantiated within a main guard:

            >>> from torchrl.collectors import MultiSyncDataCollector
            >>> if __name__ == "__main__":
            ...     # Create your collector here
            ...     collector = MultiSyncDataCollector(...)

        See https://docs.python.org/3/library/multiprocessing.html for more info.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> from torchrl.collectors import MultiSyncDataCollector
        >>> if __name__ == "__main__":
        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
        ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
        ...     collector = MultiSyncDataCollector(
        ...         create_env_fn=[env_maker, env_maker],
        ...         policy=policy,
        ...         total_frames=2000,
        ...         max_frames_per_traj=50,
        ...         frames_per_batch=200,
        ...         init_random_frames=-1,
        ...         reset_at_each_iter=False,
        ...         device="cpu",
        ...         storing_device="cpu",
        ...         cat_results="stack",
        ...     )
        ...     for i, data in enumerate(collector):
        ...         if i == 2:
        ...             print(data)
        ...             break
        ...     collector.shutdown()
        ...     del collector
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                collector: TensorDict(
                    fields={
                        traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False)

    """

    __doc__ += _MultiDataCollector.__doc__

    # for RPC
    def next(self):
        """Return the next batch (explicit override so it can be invoked remotely)."""
        return super().next()

    # for RPC
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Drop the batch buffers and shut down all worker processes.

        Args:
            timeout (float, optional): forwarded to the parent ``shutdown``.
            close_env (bool, optional): must be ``True``. Defaults to ``True``.
            raise_on_error (bool, optional): if ``False``, errors raised while
                shutting down are swallowed. Defaults to ``True``.

        Raises:
            RuntimeError: if ``close_env`` is ``False``.
        """
        if not close_env:
            raise RuntimeError(
                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
            )
        if hasattr(self, "out_buffer"):
            del self.out_buffer
        if hasattr(self, "buffers"):
            del self.buffers
        try:
            return super().shutdown(timeout=timeout)
        except Exception:
            if raise_on_error:
                raise

    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        return super().set_seed(seed, static_seed)

    # for RPC
    def state_dict(self) -> OrderedDict:
        return super().state_dict()

    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        return super().load_state_dict(state_dict)

    # for RPC
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        **kwargs,
    ) -> None:
        """Forward to the parent implementation, accepting the deprecated ``policy_weights`` kwarg."""
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            policy_or_weights = kwargs.pop("policy_weights")

        super().update_policy_weights_(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )

    def frames_per_batch_worker(self, worker_idx: int | None) -> int:
        """Return the number of frames a worker collects per batch.

        If per-worker batch sizes were given as a sequence, the entry for
        ``worker_idx`` is returned. Otherwise the requested batch size is split
        evenly across workers, rounding up (so the total may exceed the request).

        Args:
            worker_idx (int or None): index of the worker, or None.
        """
        if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
            return self._frames_per_batch[worker_idx]
        if self.requested_frames_per_batch % self.num_workers != 0 and rl_warnings():
            warnings.warn(
                f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
                f" this results in more frames_per_batch per iteration than requested. "
                "To silence this message, set the environment variable RL_WARNINGS to False."
            )
        # Ceiling division: -(-a // b) == ceil(a / b) for positive integers.
        frames_per_batch_worker = -(
            -self.requested_frames_per_batch // self.num_workers
        )
        return frames_per_batch_worker

    @property
    def _queue_len(self) -> int:
        return self.num_workers

    def _gather_sync_worker_outputs(self) -> collections.deque:
        """Collect exactly one ``(data, j)`` message per worker from ``queue_out``.

        The queue is read in ``_TIMEOUT``-second slices; each empty slice checks the
        worker processes for failures. Reading stops as soon as every worker has
        reported, so no time is wasted waiting on an already-drained queue.

        Raises:
            RuntimeError: if some outputs are still missing after
                ``_TIMEOUT * _MAX_IDLE_COUNT`` seconds (the collector is shut down first).
        """
        recv = collections.deque()
        max_wait = _TIMEOUT * _MAX_IDLE_COUNT
        t0 = time.time()
        while len(recv) < self.num_workers and (time.time() - t0) < max_wait:
            try:
                new_data, j = self.queue_out.get(timeout=_TIMEOUT)
            except (TimeoutError, Empty):
                _check_for_faulty_process(self.procs)
                continue
            recv.append((new_data, j))
        # Fail on missing data, not on elapsed time: a batch that arrived right
        # at the deadline is complete and must not trigger a shutdown.
        if len(recv) < self.num_workers:
            try:
                self.shutdown()
            finally:
                raise RuntimeError(
                    f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. "
                    f"Increase the MAX_IDLE_COUNT environment variable to bypass this error."
                )
        return recv

    def _preempt_filtered_buffers(self, cat_results) -> dict:
        """Account for frames that preempted workers did not fill (``traj_ids == -1``).

        With ``cat_results != "stack"`` a new dict of buffers restricted to the valid
        frames is returned. With ``"stack"`` a boolean ``("collector", "mask")`` entry
        is written in place into each buffer and ``self.buffers`` is returned.
        """
        if cat_results != "stack":
            buffers = {}
            for worker_idx, buffer in self.buffers.items():
                valid = buffer.get(("collector", "traj_ids")) != -1
                if valid.ndim > 2:
                    valid = valid.flatten(0, -2)
                if valid.ndim == 2:
                    valid = valid.any(0)
                buffers[worker_idx] = buffer[..., valid]
            return buffers
        for buffer in self.buffers.values():
            with buffer.unlock_():
                buffer.set(
                    ("collector", "mask"),
                    buffer.get(("collector", "traj_ids")) != -1,
                )
        return self.buffers

    def iterator(self) -> Iterator[TensorDictBase]:
        """Yield one batch per iteration, gathered synchronously from every worker.

        Each iteration sends ``"continue"`` (or ``"continue_random"`` while fewer
        than ``init_random_frames`` frames have been collected) to every worker,
        waits until each of them has reported through ``queue_out`` and then stacks
        or concatenates the per-worker buffers according to ``cat_results``. When a
        replay buffer is attached, ``None`` is yielded instead of a batch.

        Raises:
            RuntimeError: if the workers do not all report back in time, or if the
                outputs cannot be concatenated along ``cat_results``.
        """
        cat_results = self.cat_results
        if cat_results is None:
            cat_results = "stack"

        self.buffers = {}
        dones = [False for _ in range(self.num_workers)]
        workers_frames = [0 for _ in range(self.num_workers)]
        same_device = None
        self.out_buffer = None
        preempt = self.interruptor is not None and self.preemptive_threshold < 1.0

        while not all(dones) and self._frames < self.total_frames:
            _check_for_faulty_process(self.procs)
            if self.update_at_each_batch:
                self.update_policy_weights_()

            for idx in range(self.num_workers):
                if (
                    self.init_random_frames is not None
                    and self._frames < self.init_random_frames
                ):
                    msg = "continue_random"
                else:
                    msg = "continue"
                # Ask the worker to collect its share of the next batch.
                self.pipes[idx].send((None, msg))

            self._iter += 1

            if preempt:
                self.interruptor.start_collection()
                while self.queue_out.qsize() < int(
                    self.num_workers * self.preemptive_threshold
                ):
                    continue
                self.interruptor.stop_collection()
                # Now wait for stragglers to return
                while self.queue_out.qsize() < int(self.num_workers):
                    continue

            recv = self._gather_sync_worker_outputs()

            # Register every worker's output; frames are counted once all
            # buffers are in, so that preemption masking runs only once.
            received = []
            for _ in range(self.num_workers):
                new_data, j = recv.popleft()
                use_buffers = self._use_buffers
                if self.replay_buffer is not None:
                    idx = new_data
                    workers_frames[idx] = workers_frames[
                        idx
                    ] + self.frames_per_batch_worker(worker_idx=idx)
                    continue
                elif j == 0 or not use_buffers:
                    try:
                        data, idx = new_data
                        self.buffers[idx] = data
                        if use_buffers is None and j > 0:
                            self._use_buffers = False
                    except TypeError:
                        # A bare index means the worker wrote into a shared buffer.
                        if use_buffers is None:
                            self._use_buffers = True
                            idx = new_data
                        else:
                            raise
                else:
                    idx = new_data
                received.append(idx)

            if self.replay_buffer is not None:
                yield
                self._frames += sum(
                    [
                        self.frames_per_batch_worker(worker_idx)
                        for worker_idx in range(self.num_workers)
                    ]
                )
                continue

            if preempt:
                buffers = self._preempt_filtered_buffers(cat_results)
            else:
                buffers = self.buffers

            for idx in received:
                # Skip frame counting if this worker didn't send data this iteration
                # (happens when reusing buffers or on first iteration with some workers)
                if idx not in buffers:
                    continue
                workers_frames[idx] = workers_frames[idx] + buffers[idx].numel()
                if workers_frames[idx] >= self.total_frames:
                    dones[idx] = True

            # Count the frames collected in this batch. Stacked preempted outputs
            # keep their padding, so only frames with a valid traj_id are counted.
            n_collected = 0
            for buffer in buffers.values():
                traj_ids = buffer.get(("collector", "traj_ids"))
                if preempt and cat_results == "stack":
                    n_collected += int((traj_ids != -1).sum())
                else:
                    n_collected += traj_ids.numel()

            if same_device is None:
                prev_device = None
                same_device = True
                for item in self.buffers.values():
                    if prev_device is None:
                        prev_device = item.device
                    else:
                        same_device = same_device and (item.device == prev_device)

            if cat_results == "stack":
                stack = (
                    torch.stack if self._use_buffers else TensorDict.maybe_dense_stack
                )
                if same_device:
                    self.out_buffer = stack(list(buffers.values()), 0)
                else:
                    self.out_buffer = stack(
                        [item.cpu() for item in buffers.values()], 0
                    )
            else:
                if self._use_buffers is None:
                    torchrl_logger.warning(
                        "use_buffer not specified and not yet inferred from data, assuming `True`."
                    )
                elif not self._use_buffers:
                    raise RuntimeError(
                        "Cannot concatenate results with use_buffers=False"
                    )
                try:
                    if same_device:
                        self.out_buffer = torch.cat(list(buffers.values()), cat_results)
                    else:
                        self.out_buffer = torch.cat(
                            [item.cpu() for item in buffers.values()], cat_results
                        )
                except RuntimeError as err:
                    if (
                        preempt
                        and cat_results != -1
                        and "Sizes of tensors must match" in str(err)
                    ):
                        raise RuntimeError(
                            "The value provided to cat_results isn't compatible with the collectors outputs. "
                            "Consider using `cat_results=-1`."
                        )
                    raise

            # TODO: why do we need to do cat inplace and clone?
            if self.split_trajs:
                out = split_trajectories(self.out_buffer, prefix="collector")
            else:
                out = self.out_buffer
            if cat_results in (-1, "stack"):
                out.refine_names(*[None] * (out.ndim - 1) + ["time"])

            self._frames += n_collected

            if self.postprocs:
                self.postprocs = (
                    self.postprocs.to(out.device)
                    if hasattr(self.postprocs, "to")
                    else self.postprocs
                )
                out = self.postprocs(out)
            if self._exclude_private_keys:
                excluded_keys = [key for key in out.keys() if key.startswith("_")]
                if excluded_keys:
                    out = out.exclude(*excluded_keys)
            yield out
            del out

        del self.buffers
        self.out_buffer = None
        # We shall not call shutdown just yet as user may want to retrieve state_dict
        # self._shutdown_main()


@accept_remote_rref_udf_invocation
class MultiaSyncDataCollector(_MultiDataCollector):
    """Runs a given number of DataCollectors on separate processes asynchronously.

    .. aafig::


            +----------------------------------------------------------------------+
            |           "MultiConcurrentCollector"                |                |
            |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|                |
            |  "Collector 1"  |  "Collector 2"  |  "Collector 3"  |     "Main"     |
            |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~|
            | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" |                |
            |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~|
            |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" |                |
            |        |        |        |        |        |        |                |
            |       "actor"   |        |        |       "actor"   |                |
            |                 |        |        |                 |                |
            | "step" | "step" |       "actor"   |                 |                |
            |        |        |                 |                 |                |
            |        |        |                 | "step" | "step" |                |
            |        |        |                 |        |        |                |
            |       "actor"   | "step" | "step" |       "actor"   |                |
            |                 |        |        |                 |                |
            | "yield batch 1" |       "actor"   |                 |"collect, train"|
            |                 |                 |                 |                |
            | "step" | "step" |                 | "yield batch 2" |"collect, train"|
            |        |        |                 |                 |                |
            |        |        | "yield batch 3" |                 |"collect, train"|
            |        |        |                 |                 |                |
            +----------------------------------------------------------------------+

    Environment types can be identical or different.

    The collection keeps on occurring on all processes even between the time
    the batch of rollouts is collected and the next call to the iterator.
    This class can be safely used with offline RL sota-implementations.

    .. note:: Python requires multiprocessed code to be instantiated within a main guard:

            >>> from torchrl.collectors import MultiaSyncDataCollector
            >>> if __name__ == "__main__":
            ...     # Create your collector here

        See https://docs.python.org/3/library/multiprocessing.html for more info.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from tensordict.nn import TensorDictModule
        >>> from torch import nn
        >>> from torchrl.collectors import MultiaSyncDataCollector
        >>> if __name__ == "__main__":
        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
        ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
        ...     collector = MultiaSyncDataCollector(
        ...         create_env_fn=[env_maker, env_maker],
        ...         policy=policy,
        ...         total_frames=2000,
        ...         max_frames_per_traj=50,
        ...         frames_per_batch=200,
        ...         init_random_frames=-1,
        ...         reset_at_each_iter=False,
        ...         device="cpu",
        ...         storing_device="cpu",
        ...         cat_results="stack",
        ...     )
        ...     for i, data in enumerate(collector):
        ...         if i == 2:
        ...             print(data)
        ...             break
        ...     collector.shutdown()
        ...     del collector
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                collector: TensorDict(
                    fields={
                        traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([200]),
                    device=cpu,
                    is_shared=False),
                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([200]),
            device=cpu,
            is_shared=False)

    """

    __doc__ += _MultiDataCollector.__doc__

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Last tensordict received from each worker, keyed by worker index.
        # Unknown indices resolve to ``None`` instead of raising ``KeyError``.
        self.out_tensordicts = defaultdict(lambda: None)
        # True between the start of :meth:`iterator` and its exhaustion; used by
        # :meth:`reset` to know whether workers must be told to resume.
        self.running = False

        if self.postprocs is not None and self.replay_buffer is None:
            # Batches can live on any of the storing devices, so keep one
            # post-processor per device; :meth:`iterator` looks it up with
            # ``out.device``.
            postproc = self.postprocs
            self.postprocs = {}
            for _device in self.storing_device:
                if _device not in self.postprocs:
                    if hasattr(postproc, "to"):
                        postproc = deepcopy(postproc).to(_device)
                    self.postprocs[_device] = postproc

    # for RPC
    def next(self):
        return super().next()

    # for RPC
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Shut the collector down.

        Args:
            timeout: forwarded to the parent ``shutdown``.
            close_env: must be ``True``; this collector cannot be shut down
                while keeping its environments alive.
            raise_on_error: forwarded to the parent ``shutdown``.

        Raises:
            RuntimeError: if ``close_env`` is ``False``. The check happens
                before any state is released, so a rejected call leaves the
                collector untouched.
        """
        if not close_env:
            raise RuntimeError(
                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
            )
        if hasattr(self, "out_tensordicts"):
            del self.out_tensordicts
        return super().shutdown(timeout=timeout, raise_on_error=raise_on_error)

    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        return super().set_seed(seed, static_seed)

    # for RPC
    def state_dict(self) -> OrderedDict:
        return super().state_dict()

    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        return super().load_state_dict(state_dict)

    # for RPC
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        **kwargs,
    ) -> None:
        """Forward a weight update to the parent, accepting the deprecated ``policy_weights`` kwarg."""
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
            )
            policy_or_weights = kwargs.pop("policy_weights")

        super().update_policy_weights_(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )

    def frames_per_batch_worker(self, worker_idx: int | None = None) -> int:
        """Return the batch size each worker collects.

        In async mode every worker delivers a full requested batch, so
        ``worker_idx`` is ignored.
        """
        return self.requested_frames_per_batch

    def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
        """Pop one worker message from ``queue_out`` and resolve its data.

        Args:
            timeout: forwarded to ``queue_out.get``. When it elapses, the
                queue raises (the caller catches ``TimeoutError`` / ``queue.Empty``).

        Returns:
            ``(idx, j, out)``: the index of the worker that produced the
            message, the counter ``j`` sent along with it, and the
            corresponding tensordict. ``out`` is cloned unless a replay buffer
            is used. With a replay buffer, this method stores nothing in
            ``out_tensordicts``, so ``out`` is whatever is cached for that
            worker (``None`` by default).
        """
        new_data, j = self.queue_out.get(timeout=timeout)
        use_buffers = self._use_buffers
        if self.replay_buffer is not None:
            # The payload is only the worker index; the data went to the buffer.
            idx = new_data
        elif j == 0 or not use_buffers:
            # The payload may be a ``(data, idx)`` pair. If it is a bare index
            # instead, unpacking raises TypeError, which tells us the worker
            # writes into a shared buffer (use_buffers=True).
            try:
                data, idx = new_data
                self.out_tensordicts[idx] = data
                if use_buffers is None and j > 0:
                    use_buffers = self._use_buffers = False
            except TypeError:
                if use_buffers is None:
                    use_buffers = self._use_buffers = True
                    idx = new_data
                else:
                    raise
        else:
            idx = new_data
        out = self.out_tensordicts[idx]
        # Use an explicit ``is None`` test: ReplayBuffer defines ``__len__``,
        # so an *empty* buffer is falsy and ``not self.replay_buffer`` would try
        # to clone the ``None`` placeholder.
        if self.replay_buffer is None and (j == 0 or use_buffers):
            # we clone the data to make sure that we'll be working with a fixed copy
            out = out.clone()
        return idx, j, out

    @property
    def _queue_len(self) -> int:
        """Async collectors consume a single worker batch per iteration."""
        return 1

    def iterator(self) -> Iterator[TensorDictBase]:
        """Yield batches in the order workers deliver them.

        Every worker is started once. Then, until ``total_frames`` is reached,
        the first available batch is taken from the output queue,
        post-processed, and yielded. The producing worker is told to resume
        *before* the yield, so collection overlaps with the consumer's work.

        Raises:
            RuntimeError: if no batch arrives within
                ``_TIMEOUT * _MAX_IDLE_COUNT`` seconds.
        """
        # NOTE(review): policy weights are only pushed here, once, before the
        # workers are started — not after each yielded batch.
        if self.update_at_each_batch:
            self.update_policy_weights_()

        for i in range(self.num_workers):
            if self.init_random_frames is not None and self.init_random_frames > 0:
                self.pipes[i].send((None, "continue_random"))
            else:
                self.pipes[i].send((None, "continue"))
        self.running = True

        workers_frames = [0 for _ in range(self.num_workers)]
        while self._frames < self.total_frames:
            self._iter += 1
            counter = 0
            while True:
                try:
                    idx, j, out = self._get_from_queue(timeout=_TIMEOUT)
                    break
                except (TimeoutError, Empty):
                    # Nothing yet: make sure no worker died before waiting again.
                    counter += _TIMEOUT
                    _check_for_faulty_process(self.procs)
                if counter > (_TIMEOUT * _MAX_IDLE_COUNT):
                    raise RuntimeError(
                        f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. "
                        f"Increase the MAX_IDLE_COUNT environment variable to bypass this error."
                    )
            if self.replay_buffer is None:
                worker_frames = out.numel()
                if self.split_trajs:
                    out = split_trajectories(out, prefix="collector")
            else:
                worker_frames = self.frames_per_batch_worker()
            self._frames += worker_frames
            workers_frames[idx] = workers_frames[idx] + worker_frames
            if out is not None and self.postprocs:
                out = self.postprocs[out.device](out)

            # the function blocks here until the next item is asked, hence we send the message to the
            # worker to keep on working in the meantime before the yield statement
            if (
                self.init_random_frames is not None
                and self._frames < self.init_random_frames
            ):
                msg = "continue_random"
            else:
                msg = "continue"
            self.pipes[idx].send((idx, msg))
            if out is not None and self._exclude_private_keys:
                excluded_keys = [key for key in out.keys() if key.startswith("_")]
                # Skip the (copying) exclude when there is nothing to drop.
                if excluded_keys:
                    out = out.exclude(*excluded_keys)
            yield out

        # We don't want to shutdown yet, the user may want to call state_dict before
        # self._shutdown_main()
        self.running = False

    def _shutdown_main(self, *args, **kwargs) -> None:
        """Drop the cached per-worker outputs, then defer to the parent."""
        if hasattr(self, "out_tensordicts"):
            del self.out_tensordicts
        return super()._shutdown_main(*args, **kwargs)

    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
        """Reset the workers' environments and resume them if iterating.

        Args:
            reset_idx: forwarded to the parent ``reset``.

        Raises:
            RuntimeError: if ``queue_out`` is still full after waiting
                ``_TIMEOUT`` seconds.
        """
        super().reset(reset_idx)
        if self.queue_out.full():
            time.sleep(_TIMEOUT)  # wait until queue is empty
        if self.queue_out.full():
            raise RuntimeError("self.queue_out is full")
        if self.running:
            # The iterator is mid-flight: workers must be told to keep collecting.
            for idx in range(self.num_workers):
                if (
                    self.init_random_frames is not None
                    and self._frames < self.init_random_frames
                ):
                    self.pipes[idx].send((idx, "continue_random"))
                else:
                    self.pipes[idx].send((idx, "continue"))

@accept_remote_rref_udf_invocation
class aSyncDataCollector(MultiaSyncDataCollector):
    """Runs a single DataCollector on a separate process.

    This is mostly useful for offline RL paradigms where the policy being
    trained can differ from the policy used to collect data. In online
    settings, a regular DataCollector should be preferred. This class is
    merely a wrapper around a MultiaSyncDataCollector where a single process
    is being created.

    Args:
        create_env_fn (Callable): Callable returning an instance of EnvBase
        policy (Callable): Policy to be executed in the environment.
            Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
            If ``None`` is provided, the policy used will be a
            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
            ``action_spec``.
            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
            This is the recommended usage of the collector.
            Other callables are accepted too:
            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
            instances) it will be wrapped in a `nn.Module` first.
            Then, the collector will try to assess if these
            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.

            - If the policy forward signature matches any of ``forward(self, tensordict)``,
              ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
              any typing with a single argument typed as a subclass of ``TensorDictBase``)
              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.

            - In all other cases an attempt will be made to wrap it as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.

            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
                pickled directly), the ``policy_factory`` should be used instead.

    Keyword Args:
        policy_factory (Callable[[], Callable], optional): a callable that returns
            a policy instance. This is exclusive with the `policy` argument.

            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

        frames_per_batch (int): A keyword-only argument representing the
            total number of elements in a batch.
        total_frames (int, optional): A keyword-only argument representing the
            total number of frames returned by the collector
            during its lifespan. If the ``total_frames`` is not divisible by
            ``frames_per_batch``, an exception is raised.
            Endless collectors can be created by passing ``total_frames=-1``.
            Defaults to ``-1`` (never ending collector).
        device (int, str or torch.device, optional): The generic device of the
            collector. The ``device`` args fills any non-specified device: if
            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
            ``env_device`` is not specified, its value will be set to ``device``.
            Defaults to ``None`` (No default device).
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        storing_device (int, str or torch.device, optional): The device on which
            the output :class:`~tensordict.TensorDict` will be stored.
            If ``device`` is passed and ``storing_device`` is ``None``, it will
            default to the value indicated by ``device``.
            For long trajectories, it may be necessary to store the data on a different
            device than the one where the policy and env are executed.
            Defaults to ``None`` (the output tensordict isn't on a specific device,
            leaf tensors sit on the device where they were created).
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        env_device (int, str or torch.device, optional): The device on which
            the environment should be cast (or executed if that functionality is
            supported). If not specified and the env has a non-``None`` device,
            ``env_device`` will default to that value. If ``device`` is passed
            and ``env_device=None``, it will default to ``device``. If the value
            as such specified of ``env_device`` differs from ``policy_device``
            and one of them is not ``None``, the data will be cast to ``env_device``
            before being passed to the env (i.e., passing different devices to
            policy and env is supported). Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        policy_device (int, str or torch.device, optional): The device on which
            the policy should be cast.
            If ``device`` is passed and ``policy_device=None``, it will default
            to ``device``. If the value as such specified of ``policy_device``
            differs from ``env_device`` and one of them is not ``None``,
            the data will be cast to ``policy_device`` before being passed to
            the policy (i.e., passing different devices to policy and env is
            supported). Defaults to ``None``.
            Supports a list of devices if one wishes to indicate a different device
            for each worker. The list must be as long as the number of workers.
        create_env_kwargs (dict, optional): A dictionary with the
            keyword arguments used to create the environment. It is wrapped
            in a single-element list before being handed to the (single)
            sub-collector.
        max_frames_per_traj (int, optional): Maximum steps per trajectory.
            Note that a trajectory can span across multiple batches (unless
            ``reset_at_each_iter`` is set to ``True``, see below).
            Once a trajectory reaches ``n_steps``, the environment is reset.
            If the environment wraps multiple environments together, the number
            of steps is tracked for each environment independently. Negative
            values are allowed, in which case this argument is ignored.
            Defaults to ``None`` (i.e. no maximum number of steps).
        init_random_frames (int, optional): Number of frames for which the
            policy is ignored before it is called. This feature is mainly
            intended to be used in offline/model-based settings, where a
            batch of random trajectories can be used to initialize training.
            If provided, it will be rounded up to the closest multiple of frames_per_batch.
            Defaults to ``None`` (i.e. no random frames).
        reset_at_each_iter (bool, optional): Whether environments should be reset
            at the beginning of a batch collection.
            Defaults to ``False``.
        postproc (Callable, optional): A post-processing transform, such as
            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
            instance.
            Defaults to ``None``.
        split_trajs (bool, optional): Boolean indicating whether the resulting
            TensorDict should be split according to the trajectories.
            See :func:`~torchrl.collectors.utils.split_trajectories` for more
            information.
            Defaults to ``False``.
        exploration_type (ExplorationType, optional): interaction mode to be used when
            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
            or ``torchrl.envs.utils.ExplorationType.MEAN``.
        reset_when_done (bool, optional): if ``True`` (default), an environment
            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
            entry will be reset at the corresponding indices.
        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
            will be called before (sync) or after (async) each data collection.
            Defaults to ``False``.
        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
            that will be allowed to finish collecting their rollout before the rest are forced to end early.
        num_threads (int, optional): number of threads for this process.
            Defaults to the number of workers.
        num_sub_threads (int, optional): number of threads of the subprocesses.
            Should be equal to one plus the number of processes launched within
            each subprocess (or one if a single process is launched).
            Defaults to 1 for safety: if none is indicated, launching multiple
            workers may charge the cpu load too much and harm performance.
        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
            Truncated keys can be set through ``env.add_truncated_keys``.
            Defaults to ``False``.
        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
            the policy version.
            Defaults to `False`.

    """

    def __init__(
        self,
        create_env_fn: Callable[[], EnvBase],
        policy: None
        | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
        *,
        policy_factory: Callable[[], Callable] | None = None,
        frames_per_batch: int,
        total_frames: int | None = -1,
        device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None,
        create_env_kwargs: dict[str, Any] | None = None,
        max_frames_per_traj: int | None = None,
        init_random_frames: int | None = None,
        reset_at_each_iter: bool = False,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        split_trajs: bool | None = None,
        exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
        reset_when_done: bool = True,
        update_at_each_batch: bool = False,
        preemptive_threshold: float | None = None,
        num_threads: int | None = None,
        num_sub_threads: int = 1,
        set_truncated: bool = False,
        track_policy_version: bool | PolicyVersion = False,
        **kwargs,
    ):
        # A single-worker MultiaSyncDataCollector: the env factory and its
        # kwargs are wrapped in one-element lists (falsy kwargs are passed
        # through unchanged).
        super().__init__(
            create_env_fn=[create_env_fn],
            policy=policy,
            policy_factory=policy_factory,
            total_frames=total_frames,
            create_env_kwargs=[create_env_kwargs]
            if create_env_kwargs
            else create_env_kwargs,
            max_frames_per_traj=max_frames_per_traj,
            frames_per_batch=frames_per_batch,
            reset_at_each_iter=reset_at_each_iter,
            init_random_frames=init_random_frames,
            postproc=postproc,
            split_trajs=split_trajs,
            device=device,
            policy_device=policy_device,
            env_device=env_device,
            storing_device=storing_device,
            exploration_type=exploration_type,
            reset_when_done=reset_when_done,
            update_at_each_batch=update_at_each_batch,
            preemptive_threshold=preemptive_threshold,
            num_threads=num_threads,
            num_sub_threads=num_sub_threads,
            set_truncated=set_truncated,
            track_policy_version=track_policy_version,
            **kwargs,
        )

    # for RPC
    def next(self):
        return super().next()

    # for RPC
    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        return super().shutdown(
            timeout=timeout, close_env=close_env, raise_on_error=raise_on_error
        )

    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        return super().set_seed(seed, static_seed)

    # for RPC
    def state_dict(self) -> OrderedDict:
        return super().state_dict()

    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        return super().load_state_dict(state_dict)


def _main_async_collector(
    pipe_parent: connection.Connection,
    pipe_child: connection.Connection,
    queue_out: queues.Queue,
    create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],  # noqa: F821
    create_env_kwargs: dict[str, Any],
    policy: Callable[[TensorDictBase], TensorDictBase],
    max_frames_per_traj: int,
    frames_per_batch: int,
    reset_at_each_iter: bool,
    storing_device: torch.device | str | int | None,
    env_device: torch.device | str | int | None,
    policy_device: torch.device | str | int | None,
    idx: int = 0,
    exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
    reset_when_done: bool = True,
    verbose: bool = VERBOSE,
    interruptor=None,
    set_truncated: bool = False,
    use_buffers: bool | None = None,
    replay_buffer: ReplayBuffer | None = None,
    extend_buffer: bool = True,
    traj_pool: _TrajectoryPool = None,
    trust_policy: bool = False,
    compile_policy: bool = False,
    cudagraph_policy: bool = False,
    no_cuda_sync: bool = False,
    policy_factory: Callable | None = None,
    collector_class: type | Callable[[], DataCollectorBase] | None = None,
    postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
    weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
) -> None:
    if collector_class is None:
        collector_class = SyncDataCollector
    pipe_parent.close()
    # init variables that will be cleared when closing
    collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None

    try:
        collector_class._ignore_rb = extend_buffer
        inner_collector = collector_class(
            create_env_fn,
            create_env_kwargs=create_env_kwargs,
            policy=policy,
            policy_factory=policy_factory,
            total_frames=-1,
            max_frames_per_traj=max_frames_per_traj,
            frames_per_batch=frames_per_batch,
            reset_at_each_iter=reset_at_each_iter,
            postproc=postproc,
            split_trajs=False,
            storing_device=storing_device,
            policy_device=policy_device,
            env_device=env_device,
            exploration_type=exploration_type,
            reset_when_done=reset_when_done,
            return_same_td=replay_buffer is None,
            interruptor=interruptor,
            set_truncated=set_truncated,
            use_buffers=use_buffers,
            replay_buffer=replay_buffer,
            extend_buffer=False,
            traj_pool=traj_pool,
            trust_policy=trust_policy,
            compile_policy=compile_policy,
            cudagraph_policy=cudagraph_policy,
            no_cuda_sync=no_cuda_sync,
            weight_sync_schemes=weight_sync_schemes,
        )

        # Set up weight receivers for worker process
        if weight_sync_schemes:
            inner_collector._weight_receivers = {}
            inner_collector.pipe = pipe_child  # Add pipe attribute for context
            for model_id, scheme in weight_sync_schemes.items():
                # Check if scheme has new API or legacy API
                if hasattr(scheme, "init_on_worker"):
                    scheme.init_on_worker(model_id=model_id, context=inner_collector)
                    receiver = scheme.get_receiver()
                else:
                    # Legacy API
                    receiver = scheme.create_receiver()
                    receiver.set_context(inner_collector)
                    receiver.register_worker_transport(pipe_child)

                    model = _resolve_model(inner_collector, model_id)
                    receiver.register_model(model)

                inner_collector._weight_receivers[model_id] = receiver
        else:
            inner_collector._weight_receivers = {}

        use_buffers = inner_collector._use_buffers
        if verbose:
            torchrl_logger.info("Sync data collector created")
        dc_iter = iter(inner_collector)
        j = 0
        pipe_child.send("instantiated")
    except Exception as e:
        # Send error information to main process
        # We send a dict with the exception info so we can recreate it in the main process
        import traceback

        error_info = {
            "error": True,
            "exception_type": type(e).__name__,
            "exception_module": type(e).__module__,
            "exception_msg": str(e),
            "traceback": traceback.format_exc(),
        }
        try:
            pipe_child.send(error_info)
        except Exception:
            # If pipe is broken, nothing we can do
            pass
        return

    has_timed_out = False
    counter = 0
    run_free = False
    while True:
        _timeout = _TIMEOUT if not has_timed_out else 1e-3
        if not run_free and pipe_child.poll(_timeout):
            counter = 0
            data_in, msg = pipe_child.recv()
            if verbose:
                torchrl_logger.info(f"worker {idx} received {msg}")
        elif not run_free:
            if verbose:
                torchrl_logger.info(f"poll failed, j={j}, worker={idx}")
            # default is "continue" (after first iteration)
            # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe
            # in that case, the main process probably expects the worker to continue collect data
            if has_timed_out:
                counter = 0
                # has_timed_out is True if the process failed to send data, which will
                # typically occur if main has taken another batch (i.e. the queue is Full).
                # In this case, msg is the previous msg sent by main, which will typically be "continue"
                # If it's not the case, it is not expected that has_timed_out is True.
                if msg not in ("continue", "continue_random"):
                    raise RuntimeError(f"Unexpected message after time out: msg={msg}")
            else:
                # if has_timed_out is False, then the time out does not come from the fact that the queue is Full.
                # this means that our process has been waiting for a command from main in vain, while main was not
                # receiving data.
                # This will occur if main is busy doing something else (e.g. computing loss etc).

                counter += _timeout
                if verbose:
                    torchrl_logger.info(f"worker {idx} has counter {counter}")
                if counter >= (_MAX_IDLE_COUNT * _TIMEOUT):
                    raise RuntimeError(
                        f"This process waited for {counter} seconds "
                        f"without receiving a command from main. Consider increasing the maximum idle count "
                        f"if this is expected via the environment variable MAX_IDLE_COUNT "
                        f"(current value is {_MAX_IDLE_COUNT})."
                        f"\nIf this occurs at the end of a function or program, it means that your collector has not been "
                        f"collected, consider calling `collector.shutdown()` before ending the program."
                    )
                continue
        else:
            # placeholder, will be checked after
            if msg != "continue":
                torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'")
            msg = "continue"
        if msg == "run_free":
            run_free = True
            msg = "continue"
        if run_free:
            # Capture shutdown / update / seed signal, but continue should not be expected
            if pipe_child.poll(1e-4):
                data_in, msg = pipe_child.recv()
                torchrl_logger.info(f"worker {idx} received {msg} while running free")
                if msg == "continue":
                    # Switch back to run_free = False
                    run_free = False
                if msg == "pause":
                    queue_out.put((idx, "paused"), timeout=_TIMEOUT)
                    while not pipe_child.poll(1e-2):
                        continue
                    data_in, msg = pipe_child.recv()
                    if msg != "restart":
                        raise RuntimeError(f"Expected msg='restart', got {msg=}")
                    msg = "continue"
            else:
                data_in = None
                # TODO: this does not work with random frames
                msg = "continue"
        # Note: The "continue" message handling has been moved below after update_weights handling
        # to allow falling through from update_weights to continue

        if msg == "update":
            torchrl_logger.info(f"worker {idx} updating the params...")
            inner_collector.update_policy_weights_(policy_weights=data_in)
            pipe_child.send((j, "updated"))
            has_timed_out = False
            continue

        if msg == "register_shared_weights":
            # Shared memory lazy registration: main process sends buffer reference
            if verbose:
                torchrl_logger.info(
                    f"worker {idx} received shared memory buffer registration"
                )
            model_id, shared_buffer = data_in

            # Store the shared buffer reference for this model
            # The receiver will use this buffer for all future weight accesses
            if (
                inner_collector._weight_receivers
                and model_id in inner_collector._weight_receivers
            ):
                # Update receiver's buffer reference
                receiver = inner_collector._weight_receivers[model_id]
                # Store the shared buffer - the model's parameters should point to this
                if hasattr(receiver, "_shared_weights"):
                    receiver._shared_weights[model_id] = shared_buffer

                # Apply the buffer to the model immediately
                # Only apply if the model is an nn.Module (has learnable parameters)
                try:
                    model = receiver._resolve_model_ref()
                except (ValueError, AttributeError) as e:
                    # Model not registered or reference is invalid
                    if verbose:
                        torchrl_logger.warning(
                            f"worker {idx} could not resolve model '{model_id}': {e}"
                        )
                    continue

                if isinstance(model, nn.Module):
                    receiver.apply_weights(shared_buffer)
                else:
                    if verbose:
                        torchrl_logger.info(
                            f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'"
                        )

                if verbose:
                    torchrl_logger.info(
                        f"worker {idx} registered shared buffer for model '{model_id}'"
                    )
            else:
                torchrl_logger.warning(
                    f"worker {idx} received shared buffer for unknown model '{model_id}'"
                )

            # Send acknowledgment back to main process
            pipe_child.send((None, "registered"))
            has_timed_out = False
            continue

        if msg == "update_weights":
            # New weight update protocol for simplified weight sync system
            if verbose:
                torchrl_logger.info(
                    f"worker {idx} received weight update via new protocol"
                )
            model_id, weights = data_in

            # Apply weights using the appropriate receiver for this model
            if (
                inner_collector._weight_receivers
                and model_id in inner_collector._weight_receivers
            ):
                inner_collector._weight_receivers[model_id].apply_weights(weights)
            else:
                torchrl_logger.warning(
                    f"worker {idx} received weights for unknown model '{model_id}'"
                )

            # After applying weights, we continue collecting immediately as if we received
            # a "continue" message. This ensures the worker keeps collecting data without
            # waiting for an explicit continue from the main process.
            has_timed_out = False
            msg = "continue"
            # Now check if we should continue collecting

        if msg in ("continue", "continue_random"):
            # This block handles both explicit continue messages and implicit ones after weight updates
            if msg == "continue_random":
                inner_collector.init_random_frames = float("inf")
            else:
                inner_collector.init_random_frames = -1

            # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the
            # main message loop above (msg == "update_weights" case). The receiver.receive()
            # pattern is only used for schemes with separate communication channels like
            # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore).
            # Calling receiver.receive() here would interfere with the pipe-based message protocol.

            next_data = next(dc_iter)
            if pipe_child.poll(_MIN_TIMEOUT):
                # in this case, main send a message to the worker while it was busy collecting trajectories.
                # In that case, we skip the collected trajectory and get the message from main. This is faster than
                # sending the trajectory in the queue until timeout when it's never going to be received.
                continue

            if replay_buffer is not None:
                if extend_buffer:
                    next_data.names = None
                    replay_buffer.extend(next_data)

                if run_free:
                    continue

                try:
                    queue_out.put((idx, j), timeout=_TIMEOUT)
                    if verbose:
                        torchrl_logger.info(f"worker {idx} successfully sent data")
                    j += 1
                    has_timed_out = False
                    continue
                except queue.Full:
                    if verbose:
                        torchrl_logger.info(f"worker {idx} has timed out")
                    has_timed_out = True
                    continue

            if j == 0 or not use_buffers:
                collected_tensordict = next_data
                if (
                    storing_device is not None
                    and collected_tensordict.device != storing_device
                ):
                    raise RuntimeError(
                        f"expected device to be {storing_device} but got {collected_tensordict.device}"
                    )
                if use_buffers:
                    # If policy and env are on cpu, we put in shared mem,
                    # if policy is on cuda and env on cuda, we are fine with this
                    # If policy is on cuda and env on cpu (or opposite) we put tensors that
                    # are on cpu in shared mem.
                    MPS_ERROR = (
                        "tensors on mps device cannot be put in shared memory. Make sure "
                        "the shared device (aka storing_device) is set to CPU."
                    )
                    if collected_tensordict.device is not None:
                        # placeholder in case we need different behaviors
                        if collected_tensordict.device.type in ("cpu",):
                            collected_tensordict.share_memory_()
                        elif collected_tensordict.device.type in ("mps",):
                            raise RuntimeError(MPS_ERROR)
                        elif collected_tensordict.device.type == "cuda":
                            collected_tensordict.share_memory_()
                        else:
                            raise NotImplementedError(
                                f"Device {collected_tensordict.device} is not supported in multi-collectors yet."
                            )
                    else:
                        # make sure each cpu tensor is shared - assuming non-cpu devices are shared
                        def cast_tensor(x, MPS_ERROR=MPS_ERROR):
                            if x.device.type in ("cpu",):
                                x.share_memory_()
                            if x.device.type in ("mps",):
                                RuntimeError(MPS_ERROR)

                        collected_tensordict.apply(cast_tensor, filter_empty=True)
                data = (collected_tensordict, idx)
            else:
                if next_data is not collected_tensordict:
                    raise RuntimeError(
                        "SyncDataCollector should return the same tensordict modified in-place."
                    )
                data = idx  # flag the worker that has sent its data
            try:
                queue_out.put((data, j), timeout=_TIMEOUT)
                if verbose:
                    torchrl_logger.info(f"worker {idx} successfully sent data")
                j += 1
                has_timed_out = False
                continue
            except queue.Full:
                if verbose:
                    torchrl_logger.info(f"worker {idx} has timed out")
                has_timed_out = True
                continue

        if msg == "seed":
            data_in, static_seed = data_in
            new_seed = inner_collector.set_seed(data_in, static_seed=static_seed)
            torch.manual_seed(data_in)
            np.random.seed(data_in)
            pipe_child.send((new_seed, "seeded"))
            has_timed_out = False
            continue

        elif msg == "reset":
            inner_collector.reset()
            pipe_child.send((j, "reset"))
            continue

        elif msg == "state_dict":
            state_dict = inner_collector.state_dict()
            # send state_dict to cpu first
            state_dict = recursive_map_to_cpu(state_dict)
            pipe_child.send((state_dict, "state_dict"))
            has_timed_out = False
            continue

        elif msg == "load_state_dict":
            state_dict = data_in
            inner_collector.load_state_dict(state_dict)
            del state_dict
            pipe_child.send((j, "loaded"))
            has_timed_out = False
            continue

        elif msg == "getattr_policy":
            attr_name = data_in
            try:
                result = getattr(inner_collector.policy, attr_name)
                pipe_child.send((result, "getattr_policy"))
            except AttributeError as e:
                pipe_child.send((e, "getattr_policy"))
            has_timed_out = False
            continue

        elif msg == "getattr_env":
            attr_name = data_in
            try:
                result = getattr(inner_collector.env, attr_name)
                pipe_child.send((result, "getattr_env"))
            except AttributeError as e:
                pipe_child.send((e, "getattr_env"))
            has_timed_out = False
            continue

        elif msg == "close":
            del collected_tensordict, data, next_data, data_in
            inner_collector.shutdown()
            del inner_collector, dc_iter
            pipe_child.send("closed")
            if verbose:
                torchrl_logger.info(f"collector {idx} closed")
            break

        else:
            raise Exception(f"Unrecognized message {msg}")


def _make_meta_params(param):
    is_param = isinstance(param, Parameter)

    pd = param.detach().to("meta")

    if is_param:
        pd = Parameter(pd, requires_grad=False)
    return pd


class _TrajectoryPool:
    def __init__(self, ctx=None, lock: bool = False):
        self.ctx = ctx
        self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_()
        if ctx is None:
            self.lock = contextlib.nullcontext() if not lock else mp.RLock()
        else:
            self.lock = contextlib.nullcontext() if not lock else ctx.RLock()

    def get_traj_and_increment(self, n=1, device=None):
        with self.lock:
            v = self._traj_id.item()
            out = torch.arange(v, v + n).to(device)
            self._traj_id.copy_(1 + out[-1].item())
        return out


def _map_weight(
    weight,
    policy_device,
):

    is_param = isinstance(weight, Parameter)
    is_buffer = isinstance(weight, Buffer)
    weight = weight.data
    if weight.device != policy_device:
        weight = weight.to(policy_device)
    elif weight.device.type in ("cpu",):
        weight = weight.share_memory_()
    if is_param:
        weight = Parameter(weight, requires_grad=False)
    elif is_buffer:
        weight = Buffer(weight)
    return weight
